mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
Merge remote-tracking branch 'upstream/master' into addBatchIndex
This commit is contained in:
commit
89f3d2ea64
14
README.md
14
README.md
@ -87,13 +87,13 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
|||||||
|
|
||||||
Put your VAE in: models/vae
|
Put your VAE in: models/vae
|
||||||
|
|
||||||
At the time of writing this pytorch has issues with python versions higher than 3.10 so make sure your python/pip versions are 3.10.
|
|
||||||
|
|
||||||
### AMD GPUs (Linux only)
|
### AMD GPUs (Linux only)
|
||||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
|
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
|
||||||
|
|
||||||
|
This is the command to install the nightly with ROCm 5.5 that supports the 7000 series and might have some performance improvements:
|
||||||
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.5 -r requirements.txt```
|
||||||
|
|
||||||
### NVIDIA
|
### NVIDIA
|
||||||
|
|
||||||
@ -178,16 +178,6 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
|
|||||||
|
|
||||||
```embedding:embedding_filename.pt```
|
```embedding:embedding_filename.pt```
|
||||||
|
|
||||||
### Fedora
|
|
||||||
|
|
||||||
To get python 3.10 on fedora:
|
|
||||||
```dnf install python3.10```
|
|
||||||
|
|
||||||
Then you can:
|
|
||||||
|
|
||||||
```python3.10 -m ensurepip```
|
|
||||||
|
|
||||||
This will let you use: pip3.10 to install all the dependencies.
|
|
||||||
|
|
||||||
## How to increase generation speed?
|
## How to increase generation speed?
|
||||||
|
|
||||||
|
|||||||
13
comfy/checkpoint_pickle.py
Normal file
13
comfy/checkpoint_pickle.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import pickle
|
||||||
|
|
||||||
|
load = pickle.load
|
||||||
|
|
||||||
|
class Empty:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Unpickler(pickle.Unpickler):
|
||||||
|
def find_class(self, module, name):
|
||||||
|
#TODO: safe unpickle
|
||||||
|
if module.startswith("pytorch_lightning"):
|
||||||
|
return Empty
|
||||||
|
return super().find_class(module, name)
|
||||||
@ -14,8 +14,7 @@ from ..ldm.modules.diffusionmodules.util import (
|
|||||||
|
|
||||||
from ..ldm.modules.attention import SpatialTransformer
|
from ..ldm.modules.attention import SpatialTransformer
|
||||||
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
||||||
from ..ldm.models.diffusion.ddpm import LatentDiffusion
|
from ..ldm.util import exists
|
||||||
from ..ldm.util import log_txt_as_img, exists, instantiate_from_config
|
|
||||||
|
|
||||||
|
|
||||||
class ControlledUnetModel(UNetModel):
|
class ControlledUnetModel(UNetModel):
|
||||||
|
|||||||
@ -59,12 +59,14 @@ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", he
|
|||||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||||
|
|
||||||
vram_group = parser.add_mutually_exclusive_group()
|
vram_group = parser.add_mutually_exclusive_group()
|
||||||
|
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||||
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
|
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
|
||||||
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
|
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
|
||||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||||
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
||||||
|
|||||||
@ -1,12 +1,15 @@
|
|||||||
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
|
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
|
||||||
from .utils import load_torch_file, transformers_convert
|
from .utils import load_torch_file, transformers_convert
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
class ClipVisionModel():
|
class ClipVisionModel():
|
||||||
def __init__(self, json_config):
|
def __init__(self, json_config):
|
||||||
config = CLIPVisionConfig.from_json_file(json_config)
|
config = CLIPVisionConfig.from_json_file(json_config)
|
||||||
self.model = CLIPVisionModelWithProjection(config)
|
with comfy.ops.use_comfy_ops():
|
||||||
|
with modeling_utils.no_init_weights():
|
||||||
|
self.model = CLIPVisionModelWithProjection(config)
|
||||||
self.processor = CLIPImageProcessor(crop_size=224,
|
self.processor = CLIPImageProcessor(crop_size=224,
|
||||||
do_center_crop=True,
|
do_center_crop=True,
|
||||||
do_convert_rgb=True,
|
do_convert_rgb=True,
|
||||||
@ -18,7 +21,7 @@ class ClipVisionModel():
|
|||||||
size=224)
|
size=224)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
self.model.load_state_dict(sd, strict=False)
|
return self.model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
def encode_image(self, image):
|
def encode_image(self, image):
|
||||||
img = torch.clip((255. * image[0]), 0, 255).round().int()
|
img = torch.clip((255. * image[0]), 0, 255).round().int()
|
||||||
@ -56,7 +59,13 @@ def load_clipvision_from_sd(sd):
|
|||||||
else:
|
else:
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||||
clip = ClipVisionModel(json_config)
|
clip = ClipVisionModel(json_config)
|
||||||
clip.load_sd(sd)
|
m, u = clip.load_sd(sd)
|
||||||
|
u = set(u)
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
if k not in u:
|
||||||
|
t = sd.pop(k)
|
||||||
|
del t
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def load(ckpt_path):
|
def load(ckpt_path):
|
||||||
|
|||||||
@ -3,7 +3,6 @@ import os
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
from comfy.ldm.util import instantiate_from_config
|
|
||||||
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
|
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import re
|
import re
|
||||||
|
|||||||
@ -260,7 +260,8 @@ class Gligen(nn.Module):
|
|||||||
return r
|
return r
|
||||||
return func_lowvram
|
return func_lowvram
|
||||||
else:
|
else:
|
||||||
def func(key, x):
|
def func(x, extra_options):
|
||||||
|
key = extra_options["transformer_index"]
|
||||||
module = self.module_list[key]
|
module = self.module_list[key]
|
||||||
return module(x, objs)
|
return module(x, objs)
|
||||||
return func
|
return func
|
||||||
|
|||||||
@ -1,105 +0,0 @@
|
|||||||
from functools import reduce
|
|
||||||
import math
|
|
||||||
import operator
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from skimage import transform
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
def translate2d(tx, ty):
|
|
||||||
mat = [[1, 0, tx],
|
|
||||||
[0, 1, ty],
|
|
||||||
[0, 0, 1]]
|
|
||||||
return torch.tensor(mat, dtype=torch.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def scale2d(sx, sy):
|
|
||||||
mat = [[sx, 0, 0],
|
|
||||||
[ 0, sy, 0],
|
|
||||||
[ 0, 0, 1]]
|
|
||||||
return torch.tensor(mat, dtype=torch.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate2d(theta):
|
|
||||||
mat = [[torch.cos(theta), torch.sin(-theta), 0],
|
|
||||||
[torch.sin(theta), torch.cos(theta), 0],
|
|
||||||
[ 0, 0, 1]]
|
|
||||||
return torch.tensor(mat, dtype=torch.float32)
|
|
||||||
|
|
||||||
|
|
||||||
class KarrasAugmentationPipeline:
|
|
||||||
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
|
|
||||||
self.a_prob = a_prob
|
|
||||||
self.a_scale = a_scale
|
|
||||||
self.a_aniso = a_aniso
|
|
||||||
self.a_trans = a_trans
|
|
||||||
|
|
||||||
def __call__(self, image):
|
|
||||||
h, w = image.size
|
|
||||||
mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]
|
|
||||||
|
|
||||||
# x-flip
|
|
||||||
a0 = torch.randint(2, []).float()
|
|
||||||
mats.append(scale2d(1 - 2 * a0, 1))
|
|
||||||
# y-flip
|
|
||||||
do = (torch.rand([]) < self.a_prob).float()
|
|
||||||
a1 = torch.randint(2, []).float() * do
|
|
||||||
mats.append(scale2d(1, 1 - 2 * a1))
|
|
||||||
# scaling
|
|
||||||
do = (torch.rand([]) < self.a_prob).float()
|
|
||||||
a2 = torch.randn([]) * do
|
|
||||||
mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
|
|
||||||
# rotation
|
|
||||||
do = (torch.rand([]) < self.a_prob).float()
|
|
||||||
a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
|
|
||||||
mats.append(rotate2d(-a3))
|
|
||||||
# anisotropy
|
|
||||||
do = (torch.rand([]) < self.a_prob).float()
|
|
||||||
a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
|
|
||||||
a5 = torch.randn([]) * do
|
|
||||||
mats.append(rotate2d(a4))
|
|
||||||
mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
|
|
||||||
mats.append(rotate2d(-a4))
|
|
||||||
# translation
|
|
||||||
do = (torch.rand([]) < self.a_prob).float()
|
|
||||||
a6 = torch.randn([]) * do
|
|
||||||
a7 = torch.randn([]) * do
|
|
||||||
mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))
|
|
||||||
|
|
||||||
# form the transformation matrix and conditioning vector
|
|
||||||
mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
|
|
||||||
mat = reduce(operator.matmul, mats)
|
|
||||||
cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])
|
|
||||||
|
|
||||||
# apply the transformation
|
|
||||||
image_orig = np.array(image, dtype=np.float32) / 255
|
|
||||||
if image_orig.ndim == 2:
|
|
||||||
image_orig = image_orig[..., None]
|
|
||||||
tf = transform.AffineTransform(mat.numpy())
|
|
||||||
image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
|
|
||||||
image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
|
|
||||||
image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
|
|
||||||
return image, image_orig, cond
|
|
||||||
|
|
||||||
|
|
||||||
class KarrasAugmentWrapper(nn.Module):
|
|
||||||
def __init__(self, model):
|
|
||||||
super().__init__()
|
|
||||||
self.inner_model = model
|
|
||||||
|
|
||||||
def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
|
|
||||||
if aug_cond is None:
|
|
||||||
aug_cond = input.new_zeros([input.shape[0], 9])
|
|
||||||
if mapping_cond is None:
|
|
||||||
mapping_cond = aug_cond
|
|
||||||
else:
|
|
||||||
mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
|
|
||||||
return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)
|
|
||||||
|
|
||||||
def set_skip_stages(self, skip_stages):
|
|
||||||
return self.inner_model.set_skip_stages(skip_stages)
|
|
||||||
|
|
||||||
def set_patch_size(self, patch_size):
|
|
||||||
return self.inner_model.set_patch_size(patch_size)
|
|
||||||
@ -1,110 +0,0 @@
|
|||||||
from functools import partial
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from jsonmerge import merge
|
|
||||||
|
|
||||||
from . import augmentation, layers, models, utils
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(file):
|
|
||||||
defaults = {
|
|
||||||
'model': {
|
|
||||||
'sigma_data': 1.,
|
|
||||||
'patch_size': 1,
|
|
||||||
'dropout_rate': 0.,
|
|
||||||
'augment_wrapper': True,
|
|
||||||
'augment_prob': 0.,
|
|
||||||
'mapping_cond_dim': 0,
|
|
||||||
'unet_cond_dim': 0,
|
|
||||||
'cross_cond_dim': 0,
|
|
||||||
'cross_attn_depths': None,
|
|
||||||
'skip_stages': 0,
|
|
||||||
'has_variance': False,
|
|
||||||
},
|
|
||||||
'dataset': {
|
|
||||||
'type': 'imagefolder',
|
|
||||||
},
|
|
||||||
'optimizer': {
|
|
||||||
'type': 'adamw',
|
|
||||||
'lr': 1e-4,
|
|
||||||
'betas': [0.95, 0.999],
|
|
||||||
'eps': 1e-6,
|
|
||||||
'weight_decay': 1e-3,
|
|
||||||
},
|
|
||||||
'lr_sched': {
|
|
||||||
'type': 'inverse',
|
|
||||||
'inv_gamma': 20000.,
|
|
||||||
'power': 1.,
|
|
||||||
'warmup': 0.99,
|
|
||||||
},
|
|
||||||
'ema_sched': {
|
|
||||||
'type': 'inverse',
|
|
||||||
'power': 0.6667,
|
|
||||||
'max_value': 0.9999
|
|
||||||
},
|
|
||||||
}
|
|
||||||
config = json.load(file)
|
|
||||||
return merge(defaults, config)
|
|
||||||
|
|
||||||
|
|
||||||
def make_model(config):
|
|
||||||
config = config['model']
|
|
||||||
assert config['type'] == 'image_v1'
|
|
||||||
model = models.ImageDenoiserModelV1(
|
|
||||||
config['input_channels'],
|
|
||||||
config['mapping_out'],
|
|
||||||
config['depths'],
|
|
||||||
config['channels'],
|
|
||||||
config['self_attn_depths'],
|
|
||||||
config['cross_attn_depths'],
|
|
||||||
patch_size=config['patch_size'],
|
|
||||||
dropout_rate=config['dropout_rate'],
|
|
||||||
mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
|
|
||||||
unet_cond_dim=config['unet_cond_dim'],
|
|
||||||
cross_cond_dim=config['cross_cond_dim'],
|
|
||||||
skip_stages=config['skip_stages'],
|
|
||||||
has_variance=config['has_variance'],
|
|
||||||
)
|
|
||||||
if config['augment_wrapper']:
|
|
||||||
model = augmentation.KarrasAugmentWrapper(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def make_denoiser_wrapper(config):
|
|
||||||
config = config['model']
|
|
||||||
sigma_data = config.get('sigma_data', 1.)
|
|
||||||
has_variance = config.get('has_variance', False)
|
|
||||||
if not has_variance:
|
|
||||||
return partial(layers.Denoiser, sigma_data=sigma_data)
|
|
||||||
return partial(layers.DenoiserWithVariance, sigma_data=sigma_data)
|
|
||||||
|
|
||||||
|
|
||||||
def make_sample_density(config):
|
|
||||||
sd_config = config['sigma_sample_density']
|
|
||||||
sigma_data = config['sigma_data']
|
|
||||||
if sd_config['type'] == 'lognormal':
|
|
||||||
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
|
|
||||||
scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
|
|
||||||
return partial(utils.rand_log_normal, loc=loc, scale=scale)
|
|
||||||
if sd_config['type'] == 'loglogistic':
|
|
||||||
loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
|
|
||||||
scale = sd_config['scale'] if 'scale' in sd_config else 0.5
|
|
||||||
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
|
|
||||||
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
|
|
||||||
return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
|
|
||||||
if sd_config['type'] == 'loguniform':
|
|
||||||
min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
|
|
||||||
max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
|
|
||||||
return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
|
|
||||||
if sd_config['type'] == 'v-diffusion':
|
|
||||||
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
|
|
||||||
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
|
|
||||||
return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
|
|
||||||
if sd_config['type'] == 'split-lognormal':
|
|
||||||
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
|
|
||||||
scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
|
|
||||||
scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
|
|
||||||
return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
|
|
||||||
raise ValueError('Unknown sample density type')
|
|
||||||
@ -1,134 +0,0 @@
|
|||||||
import math
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from cleanfid.inception_torchscript import InceptionV3W
|
|
||||||
import clip
|
|
||||||
from resize_right import resize
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import trange
|
|
||||||
|
|
||||||
from . import utils
|
|
||||||
|
|
||||||
|
|
||||||
class InceptionV3FeatureExtractor(nn.Module):
|
|
||||||
def __init__(self, device='cpu'):
|
|
||||||
super().__init__()
|
|
||||||
path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
|
|
||||||
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
|
|
||||||
digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
|
|
||||||
utils.download_file(path / 'inception-2015-12-05.pt', url, digest)
|
|
||||||
self.model = InceptionV3W(str(path), resize_inside=False).to(device)
|
|
||||||
self.size = (299, 299)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if x.shape[2:4] != self.size:
|
|
||||||
x = resize(x, out_shape=self.size, pad_mode='reflect')
|
|
||||||
if x.shape[1] == 1:
|
|
||||||
x = torch.cat([x] * 3, dim=1)
|
|
||||||
x = (x * 127.5 + 127.5).clamp(0, 255)
|
|
||||||
return self.model(x)
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPFeatureExtractor(nn.Module):
|
|
||||||
def __init__(self, name='ViT-L/14@336px', device='cpu'):
|
|
||||||
super().__init__()
|
|
||||||
self.model = clip.load(name, device=device)[0].eval().requires_grad_(False)
|
|
||||||
self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
|
|
||||||
std=(0.26862954, 0.26130258, 0.27577711))
|
|
||||||
self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if x.shape[2:4] != self.size:
|
|
||||||
x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1)
|
|
||||||
x = self.normalize(x)
|
|
||||||
x = self.model.encode_image(x).float()
|
|
||||||
x = F.normalize(x) * x.shape[1] ** 0.5
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size):
|
|
||||||
n_per_proc = math.ceil(n / accelerator.num_processes)
|
|
||||||
feats_all = []
|
|
||||||
try:
|
|
||||||
for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process):
|
|
||||||
cur_batch_size = min(n - i, batch_size)
|
|
||||||
samples = sample_fn(cur_batch_size)[:cur_batch_size]
|
|
||||||
feats_all.append(accelerator.gather(extractor_fn(samples)))
|
|
||||||
except StopIteration:
|
|
||||||
pass
|
|
||||||
return torch.cat(feats_all)[:n]
|
|
||||||
|
|
||||||
|
|
||||||
def polynomial_kernel(x, y):
|
|
||||||
d = x.shape[-1]
|
|
||||||
dot = x @ y.transpose(-2, -1)
|
|
||||||
return (dot / d + 1) ** 3
|
|
||||||
|
|
||||||
|
|
||||||
def squared_mmd(x, y, kernel=polynomial_kernel):
|
|
||||||
m = x.shape[-2]
|
|
||||||
n = y.shape[-2]
|
|
||||||
kxx = kernel(x, x)
|
|
||||||
kyy = kernel(y, y)
|
|
||||||
kxy = kernel(x, y)
|
|
||||||
kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1)
|
|
||||||
kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1)
|
|
||||||
kxy_sum = kxy.sum([-1, -2])
|
|
||||||
term_1 = kxx_sum / m / (m - 1)
|
|
||||||
term_2 = kyy_sum / n / (n - 1)
|
|
||||||
term_3 = kxy_sum * 2 / m / n
|
|
||||||
return term_1 + term_2 - term_3
|
|
||||||
|
|
||||||
|
|
||||||
@utils.tf32_mode(matmul=False)
|
|
||||||
def kid(x, y, max_size=5000):
|
|
||||||
x_size, y_size = x.shape[0], y.shape[0]
|
|
||||||
n_partitions = math.ceil(max(x_size / max_size, y_size / max_size))
|
|
||||||
total_mmd = x.new_zeros([])
|
|
||||||
for i in range(n_partitions):
|
|
||||||
cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)]
|
|
||||||
cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)]
|
|
||||||
total_mmd = total_mmd + squared_mmd(cur_x, cur_y)
|
|
||||||
return total_mmd / n_partitions
|
|
||||||
|
|
||||||
|
|
||||||
class _MatrixSquareRootEig(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, a):
|
|
||||||
vals, vecs = torch.linalg.eigh(a)
|
|
||||||
ctx.save_for_backward(vals, vecs)
|
|
||||||
return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
vals, vecs = ctx.saved_tensors
|
|
||||||
d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
|
|
||||||
vecs_t = vecs.transpose(-2, -1)
|
|
||||||
return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t
|
|
||||||
|
|
||||||
|
|
||||||
def sqrtm_eig(a):
|
|
||||||
if a.ndim < 2:
|
|
||||||
raise RuntimeError('tensor of matrices must have at least 2 dimensions')
|
|
||||||
if a.shape[-2] != a.shape[-1]:
|
|
||||||
raise RuntimeError('tensor must be batches of square matrices')
|
|
||||||
return _MatrixSquareRootEig.apply(a)
|
|
||||||
|
|
||||||
|
|
||||||
@utils.tf32_mode(matmul=False)
|
|
||||||
def fid(x, y, eps=1e-8):
|
|
||||||
x_mean = x.mean(dim=0)
|
|
||||||
y_mean = y.mean(dim=0)
|
|
||||||
mean_term = (x_mean - y_mean).pow(2).sum()
|
|
||||||
x_cov = torch.cov(x.T)
|
|
||||||
y_cov = torch.cov(y.T)
|
|
||||||
eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps
|
|
||||||
x_cov = x_cov + eps_eye
|
|
||||||
y_cov = y_cov + eps_eye
|
|
||||||
x_cov_sqrt = sqrtm_eig(x_cov)
|
|
||||||
cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt))
|
|
||||||
return mean_term + cov_term
|
|
||||||
@ -1,99 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
class DDPGradientStatsHook:
|
|
||||||
def __init__(self, ddp_module):
|
|
||||||
try:
|
|
||||||
ddp_module.register_comm_hook(self, self._hook_fn)
|
|
||||||
except AttributeError:
|
|
||||||
raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
|
|
||||||
self._clear_state()
|
|
||||||
|
|
||||||
def _clear_state(self):
|
|
||||||
self.bucket_sq_norms_small_batch = []
|
|
||||||
self.bucket_sq_norms_large_batch = []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _hook_fn(self, bucket):
|
|
||||||
buf = bucket.buffer()
|
|
||||||
self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
|
|
||||||
fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
|
|
||||||
def callback(fut):
|
|
||||||
buf = fut.value()[0]
|
|
||||||
self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
|
|
||||||
return buf
|
|
||||||
return fut.then(callback)
|
|
||||||
|
|
||||||
def get_stats(self):
|
|
||||||
sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
|
|
||||||
sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
|
|
||||||
self._clear_state()
|
|
||||||
stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
|
|
||||||
torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
|
|
||||||
return stats[0].item(), stats[1].item()
|
|
||||||
|
|
||||||
|
|
||||||
class GradientNoiseScale:
|
|
||||||
"""Calculates the gradient noise scale (1 / SNR), or critical batch size,
|
|
||||||
from _An Empirical Model of Large-Batch Training_,
|
|
||||||
https://arxiv.org/abs/1812.06162).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
beta (float): The decay factor for the exponential moving averages used to
|
|
||||||
calculate the gradient noise scale.
|
|
||||||
Default: 0.9998
|
|
||||||
eps (float): Added for numerical stability.
|
|
||||||
Default: 1e-8
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, beta=0.9998, eps=1e-8):
|
|
||||||
self.beta = beta
|
|
||||||
self.eps = eps
|
|
||||||
self.ema_sq_norm = 0.
|
|
||||||
self.ema_var = 0.
|
|
||||||
self.beta_cumprod = 1.
|
|
||||||
self.gradient_noise_scale = float('nan')
|
|
||||||
|
|
||||||
def state_dict(self):
|
|
||||||
"""Returns the state of the object as a :class:`dict`."""
|
|
||||||
return dict(self.__dict__.items())
|
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
|
||||||
"""Loads the object's state.
|
|
||||||
Args:
|
|
||||||
state_dict (dict): object state. Should be an object returned
|
|
||||||
from a call to :meth:`state_dict`.
|
|
||||||
"""
|
|
||||||
self.__dict__.update(state_dict)
|
|
||||||
|
|
||||||
def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
|
|
||||||
"""Updates the state with a new batch's gradient statistics, and returns the
|
|
||||||
current gradient noise scale.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
|
|
||||||
per sample gradients.
|
|
||||||
sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
|
|
||||||
per sample gradients.
|
|
||||||
n_small_batch (int): The batch size of the individual microbatch or per sample
|
|
||||||
gradients (1 if per sample).
|
|
||||||
n_large_batch (int): The total batch size of the mean of the microbatch or
|
|
||||||
per sample gradients.
|
|
||||||
"""
|
|
||||||
est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
|
|
||||||
est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
|
|
||||||
self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
|
|
||||||
self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
|
|
||||||
self.beta_cumprod *= self.beta
|
|
||||||
self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
|
|
||||||
return self.gradient_noise_scale
|
|
||||||
|
|
||||||
def get_gns(self):
|
|
||||||
"""Returns the current gradient noise scale."""
|
|
||||||
return self.gradient_noise_scale
|
|
||||||
|
|
||||||
def get_stats(self):
|
|
||||||
"""Returns the current (debiased) estimates of the squared mean gradient
|
|
||||||
and gradient variance."""
|
|
||||||
return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
|
|
||||||
@ -1,246 +0,0 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from . import utils
|
|
||||||
|
|
||||||
# Karras et al. preconditioned denoiser
|
|
||||||
|
|
||||||
class Denoiser(nn.Module):
|
|
||||||
"""A Karras et al. preconditioner for denoising diffusion models."""
|
|
||||||
|
|
||||||
def __init__(self, inner_model, sigma_data=1.):
|
|
||||||
super().__init__()
|
|
||||||
self.inner_model = inner_model
|
|
||||||
self.sigma_data = sigma_data
|
|
||||||
|
|
||||||
def get_scalings(self, sigma):
|
|
||||||
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
|
|
||||||
c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
|
||||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
|
||||||
return c_skip, c_out, c_in
|
|
||||||
|
|
||||||
def loss(self, input, noise, sigma, **kwargs):
|
|
||||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
|
||||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
|
||||||
model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
|
|
||||||
target = (input - c_skip * noised_input) / c_out
|
|
||||||
return (model_output - target).pow(2).flatten(1).mean(1)
|
|
||||||
|
|
||||||
def forward(self, input, sigma, **kwargs):
|
|
||||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
|
||||||
return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip
|
|
||||||
|
|
||||||
|
|
||||||
class DenoiserWithVariance(Denoiser):
|
|
||||||
def loss(self, input, noise, sigma, **kwargs):
|
|
||||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
|
||||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
|
||||||
model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs)
|
|
||||||
logvar = utils.append_dims(logvar, model_output.ndim)
|
|
||||||
target = (input - c_skip * noised_input) / c_out
|
|
||||||
losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2
|
|
||||||
return losses.flatten(1).mean(1)
|
|
||||||
|
|
||||||
|
|
||||||
# Residual blocks
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
|
||||||
def __init__(self, *main, skip=None):
|
|
||||||
super().__init__()
|
|
||||||
self.main = nn.Sequential(*main)
|
|
||||||
self.skip = skip if skip else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
return self.main(input) + self.skip(input)
|
|
||||||
|
|
||||||
|
|
||||||
# Noise level (and other) conditioning
|
|
||||||
|
|
||||||
class ConditionedModule(nn.Module):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class UnconditionedModule(ConditionedModule):
|
|
||||||
def __init__(self, module):
|
|
||||||
super().__init__()
|
|
||||||
self.module = module
|
|
||||||
|
|
||||||
def forward(self, input, cond=None):
|
|
||||||
return self.module(input)
|
|
||||||
|
|
||||||
|
|
||||||
class ConditionedSequential(nn.Sequential, ConditionedModule):
|
|
||||||
def forward(self, input, cond):
|
|
||||||
for module in self:
|
|
||||||
if isinstance(module, ConditionedModule):
|
|
||||||
input = module(input, cond)
|
|
||||||
else:
|
|
||||||
input = module(input)
|
|
||||||
return input
|
|
||||||
|
|
||||||
|
|
||||||
class ConditionedResidualBlock(ConditionedModule):
|
|
||||||
def __init__(self, *main, skip=None):
|
|
||||||
super().__init__()
|
|
||||||
self.main = ConditionedSequential(*main)
|
|
||||||
self.skip = skip if skip else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, input, cond):
|
|
||||||
skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input)
|
|
||||||
return self.main(input, cond) + skip
|
|
||||||
|
|
||||||
|
|
||||||
class AdaGN(ConditionedModule):
|
|
||||||
def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'):
|
|
||||||
super().__init__()
|
|
||||||
self.num_groups = num_groups
|
|
||||||
self.eps = eps
|
|
||||||
self.cond_key = cond_key
|
|
||||||
self.mapper = nn.Linear(feats_in, c_out * 2)
|
|
||||||
|
|
||||||
def forward(self, input, cond):
|
|
||||||
weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1)
|
|
||||||
input = F.group_norm(input, self.num_groups, eps=self.eps)
|
|
||||||
return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1)
|
|
||||||
|
|
||||||
|
|
||||||
# Attention
|
|
||||||
|
|
||||||
class SelfAttention2d(ConditionedModule):
|
|
||||||
def __init__(self, c_in, n_head, norm, dropout_rate=0.):
|
|
||||||
super().__init__()
|
|
||||||
assert c_in % n_head == 0
|
|
||||||
self.norm_in = norm(c_in)
|
|
||||||
self.n_head = n_head
|
|
||||||
self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
|
|
||||||
self.out_proj = nn.Conv2d(c_in, c_in, 1)
|
|
||||||
self.dropout = nn.Dropout(dropout_rate)
|
|
||||||
|
|
||||||
def forward(self, input, cond):
|
|
||||||
n, c, h, w = input.shape
|
|
||||||
qkv = self.qkv_proj(self.norm_in(input, cond))
|
|
||||||
qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
|
|
||||||
q, k, v = qkv.chunk(3, dim=1)
|
|
||||||
scale = k.shape[3] ** -0.25
|
|
||||||
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
|
|
||||||
att = self.dropout(att)
|
|
||||||
y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
|
|
||||||
return input + self.out_proj(y)
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention2d(ConditionedModule):
|
|
||||||
def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0.,
|
|
||||||
cond_key='cross', cond_key_padding='cross_padding'):
|
|
||||||
super().__init__()
|
|
||||||
assert c_dec % n_head == 0
|
|
||||||
self.cond_key = cond_key
|
|
||||||
self.cond_key_padding = cond_key_padding
|
|
||||||
self.norm_enc = nn.LayerNorm(c_enc)
|
|
||||||
self.norm_dec = norm_dec(c_dec)
|
|
||||||
self.n_head = n_head
|
|
||||||
self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
|
|
||||||
self.kv_proj = nn.Linear(c_enc, c_dec * 2)
|
|
||||||
self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
|
|
||||||
self.dropout = nn.Dropout(dropout_rate)
|
|
||||||
|
|
||||||
def forward(self, input, cond):
|
|
||||||
n, c, h, w = input.shape
|
|
||||||
q = self.q_proj(self.norm_dec(input, cond))
|
|
||||||
q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3)
|
|
||||||
kv = self.kv_proj(self.norm_enc(cond[self.cond_key]))
|
|
||||||
kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2)
|
|
||||||
k, v = kv.chunk(2, dim=1)
|
|
||||||
scale = k.shape[3] ** -0.25
|
|
||||||
att = ((q * scale) @ (k.transpose(2, 3) * scale))
|
|
||||||
att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000
|
|
||||||
att = att.softmax(3)
|
|
||||||
att = self.dropout(att)
|
|
||||||
y = (att @ v).transpose(2, 3)
|
|
||||||
y = y.contiguous().view([n, c, h, w])
|
|
||||||
return input + self.out_proj(y)
|
|
||||||
|
|
||||||
|
|
||||||
# Downsampling/upsampling
|
|
||||||
|
|
||||||
_kernels = {
|
|
||||||
'linear':
|
|
||||||
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
|
||||||
'cubic':
|
|
||||||
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
|
|
||||||
0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
|
||||||
'lanczos3':
|
|
||||||
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
|
|
||||||
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
|
|
||||||
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
|
|
||||||
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
|
|
||||||
}
|
|
||||||
_kernels['bilinear'] = _kernels['linear']
|
|
||||||
_kernels['bicubic'] = _kernels['cubic']
|
|
||||||
|
|
||||||
|
|
||||||
class Downsample2d(nn.Module):
|
|
||||||
def __init__(self, kernel='linear', pad_mode='reflect'):
|
|
||||||
super().__init__()
|
|
||||||
self.pad_mode = pad_mode
|
|
||||||
kernel_1d = torch.tensor([_kernels[kernel]])
|
|
||||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
|
||||||
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
|
|
||||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
|
||||||
indices = torch.arange(x.shape[1], device=x.device)
|
|
||||||
weight[indices, indices] = self.kernel.to(weight)
|
|
||||||
return F.conv2d(x, weight, stride=2)
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample2d(nn.Module):
|
|
||||||
def __init__(self, kernel='linear', pad_mode='reflect'):
|
|
||||||
super().__init__()
|
|
||||||
self.pad_mode = pad_mode
|
|
||||||
kernel_1d = torch.tensor([_kernels[kernel]]) * 2
|
|
||||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
|
||||||
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
|
||||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
|
||||||
indices = torch.arange(x.shape[1], device=x.device)
|
|
||||||
weight[indices, indices] = self.kernel.to(weight)
|
|
||||||
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
|
|
||||||
|
|
||||||
|
|
||||||
# Embeddings
|
|
||||||
|
|
||||||
class FourierFeatures(nn.Module):
|
|
||||||
def __init__(self, in_features, out_features, std=1.):
|
|
||||||
super().__init__()
|
|
||||||
assert out_features % 2 == 0
|
|
||||||
self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
f = 2 * math.pi * input @ self.weight.T
|
|
||||||
return torch.cat([f.cos(), f.sin()], dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
# U-Nets
|
|
||||||
|
|
||||||
class UNet(ConditionedModule):
|
|
||||||
def __init__(self, d_blocks, u_blocks, skip_stages=0):
|
|
||||||
super().__init__()
|
|
||||||
self.d_blocks = nn.ModuleList(d_blocks)
|
|
||||||
self.u_blocks = nn.ModuleList(u_blocks)
|
|
||||||
self.skip_stages = skip_stages
|
|
||||||
|
|
||||||
def forward(self, input, cond):
|
|
||||||
skips = []
|
|
||||||
for block in self.d_blocks[self.skip_stages:]:
|
|
||||||
input = block(input, cond)
|
|
||||||
skips.append(input)
|
|
||||||
for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
|
|
||||||
input = block(input, cond, skip if i > 0 else None)
|
|
||||||
return input
|
|
||||||
@ -1 +0,0 @@
|
|||||||
from .image_v1 import ImageDenoiserModelV1
|
|
||||||
@ -1,156 +0,0 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from .. import layers, utils
|
|
||||||
|
|
||||||
|
|
||||||
def orthogonal_(module):
|
|
||||||
nn.init.orthogonal_(module.weight)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
class ResConvBlock(layers.ConditionedResidualBlock):
|
|
||||||
def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
|
|
||||||
skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
|
|
||||||
super().__init__(
|
|
||||||
layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Conv2d(c_in, c_mid, 3, padding=1),
|
|
||||||
nn.Dropout2d(dropout_rate, inplace=True),
|
|
||||||
layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Conv2d(c_mid, c_out, 3, padding=1),
|
|
||||||
nn.Dropout2d(dropout_rate, inplace=True),
|
|
||||||
skip=skip)
|
|
||||||
|
|
||||||
|
|
||||||
class DBlock(layers.ConditionedSequential):
|
|
||||||
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0):
|
|
||||||
modules = [nn.Identity()]
|
|
||||||
for i in range(n_layers):
|
|
||||||
my_c_in = c_in if i == 0 else c_mid
|
|
||||||
my_c_out = c_mid if i < n_layers - 1 else c_out
|
|
||||||
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
|
|
||||||
if self_attn:
|
|
||||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
|
||||||
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
|
|
||||||
if cross_attn:
|
|
||||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
|
||||||
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
|
|
||||||
super().__init__(*modules)
|
|
||||||
self.set_downsample(downsample)
|
|
||||||
|
|
||||||
def set_downsample(self, downsample):
|
|
||||||
self[0] = layers.Downsample2d() if downsample else nn.Identity()
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class UBlock(layers.ConditionedSequential):
|
|
||||||
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0):
|
|
||||||
modules = []
|
|
||||||
for i in range(n_layers):
|
|
||||||
my_c_in = c_in if i == 0 else c_mid
|
|
||||||
my_c_out = c_mid if i < n_layers - 1 else c_out
|
|
||||||
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
|
|
||||||
if self_attn:
|
|
||||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
|
||||||
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
|
|
||||||
if cross_attn:
|
|
||||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
|
||||||
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
|
|
||||||
modules.append(nn.Identity())
|
|
||||||
super().__init__(*modules)
|
|
||||||
self.set_upsample(upsample)
|
|
||||||
|
|
||||||
def forward(self, input, cond, skip=None):
|
|
||||||
if skip is not None:
|
|
||||||
input = torch.cat([input, skip], dim=1)
|
|
||||||
return super().forward(input, cond)
|
|
||||||
|
|
||||||
def set_upsample(self, upsample):
|
|
||||||
self[-1] = layers.Upsample2d() if upsample else nn.Identity()
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class MappingNet(nn.Sequential):
|
|
||||||
def __init__(self, feats_in, feats_out, n_layers=2):
|
|
||||||
layers = []
|
|
||||||
for i in range(n_layers):
|
|
||||||
layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out)))
|
|
||||||
layers.append(nn.GELU())
|
|
||||||
super().__init__(*layers)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageDenoiserModelV1(nn.Module):
|
|
||||||
def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False):
|
|
||||||
super().__init__()
|
|
||||||
self.c_in = c_in
|
|
||||||
self.channels = channels
|
|
||||||
self.unet_cond_dim = unet_cond_dim
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.has_variance = has_variance
|
|
||||||
self.timestep_embed = layers.FourierFeatures(1, feats_in)
|
|
||||||
if mapping_cond_dim > 0:
|
|
||||||
self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
|
|
||||||
self.mapping = MappingNet(feats_in, feats_in)
|
|
||||||
self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1)
|
|
||||||
self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
|
|
||||||
nn.init.zeros_(self.proj_out.weight)
|
|
||||||
nn.init.zeros_(self.proj_out.bias)
|
|
||||||
if cross_cond_dim == 0:
|
|
||||||
cross_attn_depths = [False] * len(self_attn_depths)
|
|
||||||
d_blocks, u_blocks = [], []
|
|
||||||
for i in range(len(depths)):
|
|
||||||
my_c_in = channels[max(0, i - 1)]
|
|
||||||
d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
|
|
||||||
for i in range(len(depths)):
|
|
||||||
my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
|
|
||||||
my_c_out = channels[max(0, i - 1)]
|
|
||||||
u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
|
|
||||||
self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)
|
|
||||||
|
|
||||||
def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False):
|
|
||||||
c_noise = sigma.log() / 4
|
|
||||||
timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
|
|
||||||
mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
|
|
||||||
mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
|
|
||||||
cond = {'cond': mapping_out}
|
|
||||||
if unet_cond is not None:
|
|
||||||
input = torch.cat([input, unet_cond], dim=1)
|
|
||||||
if cross_cond is not None:
|
|
||||||
cond['cross'] = cross_cond
|
|
||||||
cond['cross_padding'] = cross_cond_padding
|
|
||||||
if self.patch_size > 1:
|
|
||||||
input = F.pixel_unshuffle(input, self.patch_size)
|
|
||||||
input = self.proj_in(input)
|
|
||||||
input = self.u_net(input, cond)
|
|
||||||
input = self.proj_out(input)
|
|
||||||
if self.has_variance:
|
|
||||||
input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
|
|
||||||
if self.patch_size > 1:
|
|
||||||
input = F.pixel_shuffle(input, self.patch_size)
|
|
||||||
if self.has_variance and return_variance:
|
|
||||||
return input, logvar
|
|
||||||
return input
|
|
||||||
|
|
||||||
def set_skip_stages(self, skip_stages):
|
|
||||||
self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1)
|
|
||||||
self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1)
|
|
||||||
nn.init.zeros_(self.proj_out.weight)
|
|
||||||
nn.init.zeros_(self.proj_out.bias)
|
|
||||||
self.u_net.skip_stages = skip_stages
|
|
||||||
for i, block in enumerate(self.u_net.d_blocks):
|
|
||||||
block.set_downsample(i > skip_stages)
|
|
||||||
for i, block in enumerate(reversed(self.u_net.u_blocks)):
|
|
||||||
block.set_upsample(i > skip_stages)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def set_patch_size(self, patch_size):
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1)
|
|
||||||
self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
|
|
||||||
nn.init.zeros_(self.proj_out.weight)
|
|
||||||
nn.init.zeros_(self.proj_out.bias)
|
|
||||||
@ -10,25 +10,6 @@ from PIL import Image
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn, optim
|
from torch import nn, optim
|
||||||
from torch.utils import data
|
from torch.utils import data
|
||||||
from torchvision.transforms import functional as TF
|
|
||||||
|
|
||||||
|
|
||||||
def from_pil_image(x):
|
|
||||||
"""Converts from a PIL image to a tensor."""
|
|
||||||
x = TF.to_tensor(x)
|
|
||||||
if x.ndim == 2:
|
|
||||||
x = x[..., None]
|
|
||||||
return x * 2 - 1
|
|
||||||
|
|
||||||
|
|
||||||
def to_pil_image(x):
|
|
||||||
"""Converts from a tensor to a PIL image."""
|
|
||||||
if x.ndim == 4:
|
|
||||||
assert x.shape[0] == 1
|
|
||||||
x = x[0]
|
|
||||||
if x.shape[0] == 1:
|
|
||||||
x = x[0]
|
|
||||||
return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2)
|
|
||||||
|
|
||||||
|
|
||||||
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
|
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
|
||||||
|
|||||||
@ -1,24 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
from ldm.modules.midas.api import load_midas_transform
|
|
||||||
|
|
||||||
|
|
||||||
class AddMiDaS(object):
|
|
||||||
def __init__(self, model_type):
|
|
||||||
super().__init__()
|
|
||||||
self.transform = load_midas_transform(model_type)
|
|
||||||
|
|
||||||
def pt2np(self, x):
|
|
||||||
x = ((x + 1.0) * .5).detach().cpu().numpy()
|
|
||||||
return x
|
|
||||||
|
|
||||||
def np2pt(self, x):
|
|
||||||
x = torch.from_numpy(x) * 2 - 1.
|
|
||||||
return x
|
|
||||||
|
|
||||||
def __call__(self, sample):
|
|
||||||
# sample['jpg'] is tensor hwc in [-1, 1] at this point
|
|
||||||
x = self.pt2np(sample['jpg'])
|
|
||||||
x = self.transform({"image": x})["image"]
|
|
||||||
sample['midas_in'] = x
|
|
||||||
return sample
|
|
||||||
@ -284,7 +284,7 @@ class DDIMSampler(object):
|
|||||||
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
||||||
|
|
||||||
if self.model.parameterization == "v":
|
if self.model.parameterization == "v":
|
||||||
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
|
||||||
else:
|
else:
|
||||||
e_t = model_output
|
e_t = model_output
|
||||||
|
|
||||||
@ -306,7 +306,7 @@ class DDIMSampler(object):
|
|||||||
if self.model.parameterization != "v":
|
if self.model.parameterization != "v":
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
else:
|
else:
|
||||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
|
||||||
|
|
||||||
if quantize_denoised:
|
if quantize_denoised:
|
||||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -10,6 +10,7 @@ from .diffusionmodules.util import checkpoint
|
|||||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||||
|
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
from . import tomesd
|
from . import tomesd
|
||||||
|
|
||||||
@ -50,9 +51,9 @@ def init_(tensor):
|
|||||||
|
|
||||||
# feedforward
|
# feedforward
|
||||||
class GEGLU(nn.Module):
|
class GEGLU(nn.Module):
|
||||||
def __init__(self, dim_in, dim_out):
|
def __init__(self, dim_in, dim_out, dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||||
@ -60,19 +61,19 @@ class GEGLU(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForward(nn.Module):
|
||||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = int(dim * mult)
|
inner_dim = int(dim * mult)
|
||||||
dim_out = default(dim_out, dim)
|
dim_out = default(dim_out, dim)
|
||||||
project_in = nn.Sequential(
|
project_in = nn.Sequential(
|
||||||
nn.Linear(dim, inner_dim),
|
comfy.ops.Linear(dim, inner_dim, dtype=dtype),
|
||||||
nn.GELU()
|
nn.GELU()
|
||||||
) if not glu else GEGLU(dim, inner_dim)
|
) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
|
||||||
|
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
project_in,
|
project_in,
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(inner_dim, dim_out)
|
comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -88,8 +89,8 @@ def zero_module(module):
|
|||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
def Normalize(in_channels):
|
def Normalize(in_channels, dtype=None):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
class SpatialSelfAttention(nn.Module):
|
class SpatialSelfAttention(nn.Module):
|
||||||
@ -146,7 +147,7 @@ class SpatialSelfAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class CrossAttentionBirchSan(nn.Module):
|
class CrossAttentionBirchSan(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -154,12 +155,12 @@ class CrossAttentionBirchSan(nn.Module):
|
|||||||
self.scale = dim_head ** -0.5
|
self.scale = dim_head ** -0.5
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
|
|
||||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
nn.Linear(inner_dim, query_dim),
|
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
|
||||||
nn.Dropout(dropout)
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -243,7 +244,7 @@ class CrossAttentionBirchSan(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class CrossAttentionDoggettx(nn.Module):
|
class CrossAttentionDoggettx(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -251,12 +252,12 @@ class CrossAttentionDoggettx(nn.Module):
|
|||||||
self.scale = dim_head ** -0.5
|
self.scale = dim_head ** -0.5
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
|
|
||||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
nn.Linear(inner_dim, query_dim),
|
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
|
||||||
nn.Dropout(dropout)
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -341,7 +342,7 @@ class CrossAttentionDoggettx(nn.Module):
|
|||||||
return self.to_out(r2)
|
return self.to_out(r2)
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -349,12 +350,12 @@ class CrossAttention(nn.Module):
|
|||||||
self.scale = dim_head ** -0.5
|
self.scale = dim_head ** -0.5
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
|
|
||||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
nn.Linear(inner_dim, query_dim),
|
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
|
||||||
nn.Dropout(dropout)
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -397,7 +398,7 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
class MemoryEfficientCrossAttention(nn.Module):
|
class MemoryEfficientCrossAttention(nn.Module):
|
||||||
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||||
f"{heads} heads.")
|
f"{heads} heads.")
|
||||||
@ -407,11 +408,11 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.dim_head = dim_head
|
self.dim_head = dim_head
|
||||||
|
|
||||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
|
||||||
self.attention_op: Optional[Any] = None
|
self.attention_op: Optional[Any] = None
|
||||||
|
|
||||||
def forward(self, x, context=None, value=None, mask=None):
|
def forward(self, x, context=None, value=None, mask=None):
|
||||||
@ -448,7 +449,7 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
class CrossAttentionPytorch(nn.Module):
|
class CrossAttentionPytorch(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -456,11 +457,11 @@ class CrossAttentionPytorch(nn.Module):
|
|||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.dim_head = dim_head
|
self.dim_head = dim_head
|
||||||
|
|
||||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
|
||||||
self.attention_op: Optional[Any] = None
|
self.attention_op: Optional[Any] = None
|
||||||
|
|
||||||
def forward(self, x, context=None, value=None, mask=None):
|
def forward(self, x, context=None, value=None, mask=None):
|
||||||
@ -506,26 +507,30 @@ else:
|
|||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
||||||
disable_self_attn=False):
|
disable_self_attn=False, dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.disable_self_attn = disable_self_attn
|
self.disable_self_attn = disable_self_attn
|
||||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||||
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype) # is a self-attention if not self.disable_self_attn
|
||||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
|
||||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||||
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype) # is self-attn if context is none
|
||||||
self.norm1 = nn.LayerNorm(dim)
|
self.norm1 = nn.LayerNorm(dim, dtype=dtype)
|
||||||
self.norm2 = nn.LayerNorm(dim)
|
self.norm2 = nn.LayerNorm(dim, dtype=dtype)
|
||||||
self.norm3 = nn.LayerNorm(dim)
|
self.norm3 = nn.LayerNorm(dim, dtype=dtype)
|
||||||
self.checkpoint = checkpoint
|
self.checkpoint = checkpoint
|
||||||
|
|
||||||
def forward(self, x, context=None, transformer_options={}):
|
def forward(self, x, context=None, transformer_options={}):
|
||||||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||||
|
|
||||||
def _forward(self, x, context=None, transformer_options={}):
|
def _forward(self, x, context=None, transformer_options={}):
|
||||||
current_index = None
|
extra_options = {}
|
||||||
if "current_index" in transformer_options:
|
if "current_index" in transformer_options:
|
||||||
current_index = transformer_options["current_index"]
|
extra_options["transformer_index"] = transformer_options["current_index"]
|
||||||
|
if "block_index" in transformer_options:
|
||||||
|
extra_options["block_index"] = transformer_options["block_index"]
|
||||||
|
if "original_shape" in transformer_options:
|
||||||
|
extra_options["original_shape"] = transformer_options["original_shape"]
|
||||||
if "patches" in transformer_options:
|
if "patches" in transformer_options:
|
||||||
transformer_patches = transformer_options["patches"]
|
transformer_patches = transformer_options["patches"]
|
||||||
else:
|
else:
|
||||||
@ -544,7 +549,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
context_attn1 = n
|
context_attn1 = n
|
||||||
value_attn1 = context_attn1
|
value_attn1 = context_attn1
|
||||||
for p in patch:
|
for p in patch:
|
||||||
n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
|
n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
|
||||||
|
|
||||||
if "tomesd" in transformer_options:
|
if "tomesd" in transformer_options:
|
||||||
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
||||||
@ -556,7 +561,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
if "middle_patch" in transformer_patches:
|
if "middle_patch" in transformer_patches:
|
||||||
patch = transformer_patches["middle_patch"]
|
patch = transformer_patches["middle_patch"]
|
||||||
for p in patch:
|
for p in patch:
|
||||||
x = p(current_index, x)
|
x = p(x, extra_options)
|
||||||
|
|
||||||
n = self.norm2(x)
|
n = self.norm2(x)
|
||||||
|
|
||||||
@ -566,10 +571,15 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
patch = transformer_patches["attn2_patch"]
|
patch = transformer_patches["attn2_patch"]
|
||||||
value_attn2 = context_attn2
|
value_attn2 = context_attn2
|
||||||
for p in patch:
|
for p in patch:
|
||||||
n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
|
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
|
||||||
|
|
||||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||||
|
|
||||||
|
if "attn2_output_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn2_output_patch"]
|
||||||
|
for p in patch:
|
||||||
|
n = p(n, extra_options)
|
||||||
|
|
||||||
x += n
|
x += n
|
||||||
x = self.ff(self.norm3(x)) + x
|
x = self.ff(self.norm3(x)) + x
|
||||||
return x
|
return x
|
||||||
@ -587,35 +597,34 @@ class SpatialTransformer(nn.Module):
|
|||||||
def __init__(self, in_channels, n_heads, d_head,
|
def __init__(self, in_channels, n_heads, d_head,
|
||||||
depth=1, dropout=0., context_dim=None,
|
depth=1, dropout=0., context_dim=None,
|
||||||
disable_self_attn=False, use_linear=False,
|
disable_self_attn=False, use_linear=False,
|
||||||
use_checkpoint=True):
|
use_checkpoint=True, dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if exists(context_dim) and not isinstance(context_dim, list):
|
if exists(context_dim) and not isinstance(context_dim, list):
|
||||||
context_dim = [context_dim]
|
context_dim = [context_dim]
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
inner_dim = n_heads * d_head
|
inner_dim = n_heads * d_head
|
||||||
self.norm = Normalize(in_channels)
|
self.norm = Normalize(in_channels, dtype=dtype)
|
||||||
if not use_linear:
|
if not use_linear:
|
||||||
self.proj_in = nn.Conv2d(in_channels,
|
self.proj_in = nn.Conv2d(in_channels,
|
||||||
inner_dim,
|
inner_dim,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
|
||||||
|
|
||||||
self.transformer_blocks = nn.ModuleList(
|
self.transformer_blocks = nn.ModuleList(
|
||||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
||||||
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
|
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
|
||||||
for d in range(depth)]
|
for d in range(depth)]
|
||||||
)
|
)
|
||||||
if not use_linear:
|
if not use_linear:
|
||||||
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
self.proj_out = nn.Conv2d(inner_dim,in_channels,
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0))
|
padding=0, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
|
||||||
self.use_linear = use_linear
|
self.use_linear = use_linear
|
||||||
|
|
||||||
def forward(self, x, context=None, transformer_options={}):
|
def forward(self, x, context=None, transformer_options={}):
|
||||||
@ -631,6 +640,7 @@ class SpatialTransformer(nn.Module):
|
|||||||
if self.use_linear:
|
if self.use_linear:
|
||||||
x = self.proj_in(x)
|
x = self.proj_in(x)
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
x = block(x, context=context[i], transformer_options=transformer_options)
|
x = block(x, context=context[i], transformer_options=transformer_options)
|
||||||
if self.use_linear:
|
if self.use_linear:
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
|
|||||||
@ -111,14 +111,14 @@ class Upsample(nn.Module):
|
|||||||
upsampling occurs in the inner-two dimensions.
|
upsampling occurs in the inner-two dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.out_channels = out_channels or channels
|
self.out_channels = out_channels or channels
|
||||||
self.use_conv = use_conv
|
self.use_conv = use_conv
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
if use_conv:
|
if use_conv:
|
||||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
|
||||||
|
|
||||||
def forward(self, x, output_shape=None):
|
def forward(self, x, output_shape=None):
|
||||||
assert x.shape[1] == self.channels
|
assert x.shape[1] == self.channels
|
||||||
@ -160,7 +160,7 @@ class Downsample(nn.Module):
|
|||||||
downsampling occurs in the inner-two dimensions.
|
downsampling occurs in the inner-two dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
|
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.out_channels = out_channels or channels
|
self.out_channels = out_channels or channels
|
||||||
@ -169,7 +169,7 @@ class Downsample(nn.Module):
|
|||||||
stride = 2 if dims != 3 else (1, 2, 2)
|
stride = 2 if dims != 3 else (1, 2, 2)
|
||||||
if use_conv:
|
if use_conv:
|
||||||
self.op = conv_nd(
|
self.op = conv_nd(
|
||||||
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
|
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert self.channels == self.out_channels
|
assert self.channels == self.out_channels
|
||||||
@ -208,6 +208,7 @@ class ResBlock(TimestepBlock):
|
|||||||
use_checkpoint=False,
|
use_checkpoint=False,
|
||||||
up=False,
|
up=False,
|
||||||
down=False,
|
down=False,
|
||||||
|
dtype=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
@ -219,19 +220,19 @@ class ResBlock(TimestepBlock):
|
|||||||
self.use_scale_shift_norm = use_scale_shift_norm
|
self.use_scale_shift_norm = use_scale_shift_norm
|
||||||
|
|
||||||
self.in_layers = nn.Sequential(
|
self.in_layers = nn.Sequential(
|
||||||
normalization(channels),
|
normalization(channels, dtype=dtype),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.updown = up or down
|
self.updown = up or down
|
||||||
|
|
||||||
if up:
|
if up:
|
||||||
self.h_upd = Upsample(channels, False, dims)
|
self.h_upd = Upsample(channels, False, dims, dtype=dtype)
|
||||||
self.x_upd = Upsample(channels, False, dims)
|
self.x_upd = Upsample(channels, False, dims, dtype=dtype)
|
||||||
elif down:
|
elif down:
|
||||||
self.h_upd = Downsample(channels, False, dims)
|
self.h_upd = Downsample(channels, False, dims, dtype=dtype)
|
||||||
self.x_upd = Downsample(channels, False, dims)
|
self.x_upd = Downsample(channels, False, dims, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
self.h_upd = self.x_upd = nn.Identity()
|
self.h_upd = self.x_upd = nn.Identity()
|
||||||
|
|
||||||
@ -239,15 +240,15 @@ class ResBlock(TimestepBlock):
|
|||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
linear(
|
linear(
|
||||||
emb_channels,
|
emb_channels,
|
||||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.out_layers = nn.Sequential(
|
self.out_layers = nn.Sequential(
|
||||||
normalization(self.out_channels),
|
normalization(self.out_channels, dtype=dtype),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Dropout(p=dropout),
|
nn.Dropout(p=dropout),
|
||||||
zero_module(
|
zero_module(
|
||||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -255,10 +256,10 @@ class ResBlock(TimestepBlock):
|
|||||||
self.skip_connection = nn.Identity()
|
self.skip_connection = nn.Identity()
|
||||||
elif use_conv:
|
elif use_conv:
|
||||||
self.skip_connection = conv_nd(
|
self.skip_connection = conv_nd(
|
||||||
dims, channels, self.out_channels, 3, padding=1
|
dims, channels, self.out_channels, 3, padding=1, dtype=dtype
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
|
||||||
|
|
||||||
def forward(self, x, emb):
|
def forward(self, x, emb):
|
||||||
"""
|
"""
|
||||||
@ -558,9 +559,9 @@ class UNetModel(nn.Module):
|
|||||||
|
|
||||||
time_embed_dim = model_channels * 4
|
time_embed_dim = model_channels * 4
|
||||||
self.time_embed = nn.Sequential(
|
self.time_embed = nn.Sequential(
|
||||||
linear(model_channels, time_embed_dim),
|
linear(model_channels, time_embed_dim, dtype=self.dtype),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
linear(time_embed_dim, time_embed_dim),
|
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
@ -573,9 +574,9 @@ class UNetModel(nn.Module):
|
|||||||
assert adm_in_channels is not None
|
assert adm_in_channels is not None
|
||||||
self.label_emb = nn.Sequential(
|
self.label_emb = nn.Sequential(
|
||||||
nn.Sequential(
|
nn.Sequential(
|
||||||
linear(adm_in_channels, time_embed_dim),
|
linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
linear(time_embed_dim, time_embed_dim),
|
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -584,7 +585,7 @@ class UNetModel(nn.Module):
|
|||||||
self.input_blocks = nn.ModuleList(
|
self.input_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
TimestepEmbedSequential(
|
TimestepEmbedSequential(
|
||||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -603,6 +604,7 @@ class UNetModel(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
dtype=self.dtype
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
ch = mult * model_channels
|
ch = mult * model_channels
|
||||||
@ -631,7 +633,7 @@ class UNetModel(nn.Module):
|
|||||||
) if not use_spatial_transformer else SpatialTransformer(
|
) if not use_spatial_transformer else SpatialTransformer(
|
||||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint
|
use_checkpoint=use_checkpoint, dtype=self.dtype
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
@ -650,10 +652,11 @@ class UNetModel(nn.Module):
|
|||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
down=True,
|
down=True,
|
||||||
|
dtype=self.dtype
|
||||||
)
|
)
|
||||||
if resblock_updown
|
if resblock_updown
|
||||||
else Downsample(
|
else Downsample(
|
||||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -678,6 +681,7 @@ class UNetModel(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
dtype=self.dtype
|
||||||
),
|
),
|
||||||
AttentionBlock(
|
AttentionBlock(
|
||||||
ch,
|
ch,
|
||||||
@ -688,7 +692,7 @@ class UNetModel(nn.Module):
|
|||||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint
|
use_checkpoint=use_checkpoint, dtype=self.dtype
|
||||||
),
|
),
|
||||||
ResBlock(
|
ResBlock(
|
||||||
ch,
|
ch,
|
||||||
@ -697,6 +701,7 @@ class UNetModel(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
dtype=self.dtype
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
@ -714,6 +719,7 @@ class UNetModel(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
dtype=self.dtype
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
ch = model_channels * mult
|
ch = model_channels * mult
|
||||||
@ -742,7 +748,7 @@ class UNetModel(nn.Module):
|
|||||||
) if not use_spatial_transformer else SpatialTransformer(
|
) if not use_spatial_transformer else SpatialTransformer(
|
||||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint
|
use_checkpoint=use_checkpoint, dtype=self.dtype
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if level and i == self.num_res_blocks[level]:
|
if level and i == self.num_res_blocks[level]:
|
||||||
@ -757,18 +763,19 @@ class UNetModel(nn.Module):
|
|||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
up=True,
|
up=True,
|
||||||
|
dtype=self.dtype
|
||||||
)
|
)
|
||||||
if resblock_updown
|
if resblock_updown
|
||||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
|
||||||
)
|
)
|
||||||
ds //= 2
|
ds //= 2
|
||||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
self.out = nn.Sequential(
|
self.out = nn.Sequential(
|
||||||
normalization(ch),
|
normalization(ch, dtype=self.dtype),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
|
||||||
)
|
)
|
||||||
if self.predict_codebook_ids:
|
if self.predict_codebook_ids:
|
||||||
self.id_predictor = nn.Sequential(
|
self.id_predictor = nn.Sequential(
|
||||||
|
|||||||
@ -16,7 +16,7 @@ import numpy as np
|
|||||||
from einops import repeat
|
from einops import repeat
|
||||||
|
|
||||||
from comfy.ldm.util import instantiate_from_config
|
from comfy.ldm.util import instantiate_from_config
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||||
if schedule == "linear":
|
if schedule == "linear":
|
||||||
@ -206,13 +206,13 @@ def mean_flat(tensor):
|
|||||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||||
|
|
||||||
|
|
||||||
def normalization(channels):
|
def normalization(channels, dtype=None):
|
||||||
"""
|
"""
|
||||||
Make a standard normalization layer.
|
Make a standard normalization layer.
|
||||||
:param channels: number of input channels.
|
:param channels: number of input channels.
|
||||||
:return: an nn.Module for normalization.
|
:return: an nn.Module for normalization.
|
||||||
"""
|
"""
|
||||||
return GroupNorm32(32, channels)
|
return GroupNorm32(32, channels, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||||
@ -233,7 +233,7 @@ def conv_nd(dims, *args, **kwargs):
|
|||||||
if dims == 1:
|
if dims == 1:
|
||||||
return nn.Conv1d(*args, **kwargs)
|
return nn.Conv1d(*args, **kwargs)
|
||||||
elif dims == 2:
|
elif dims == 2:
|
||||||
return nn.Conv2d(*args, **kwargs)
|
return comfy.ops.Conv2d(*args, **kwargs)
|
||||||
elif dims == 3:
|
elif dims == 3:
|
||||||
return nn.Conv3d(*args, **kwargs)
|
return nn.Conv3d(*args, **kwargs)
|
||||||
raise ValueError(f"unsupported dimensions: {dims}")
|
raise ValueError(f"unsupported dimensions: {dims}")
|
||||||
@ -243,7 +243,7 @@ def linear(*args, **kwargs):
|
|||||||
"""
|
"""
|
||||||
Create a linear module.
|
Create a linear module.
|
||||||
"""
|
"""
|
||||||
return nn.Linear(*args, **kwargs)
|
return comfy.ops.Linear(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def avg_pool_nd(dims, *args, **kwargs):
|
def avg_pool_nd(dims, *args, **kwargs):
|
||||||
|
|||||||
@ -1,59 +0,0 @@
|
|||||||
|
|
||||||
|
|
||||||
from typing import List, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py
|
|
||||||
|
|
||||||
def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
|
|
||||||
r"""Normalize an image/video tensor with mean and standard deviation.
|
|
||||||
.. math::
|
|
||||||
\text{input[channel] = (input[channel] - mean[channel]) / std[channel]}
|
|
||||||
Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,
|
|
||||||
Args:
|
|
||||||
data: Image tensor of size :math:`(B, C, *)`.
|
|
||||||
mean: Mean for each channel.
|
|
||||||
std: Standard deviations for each channel.
|
|
||||||
Return:
|
|
||||||
Normalised tensor with same size as input :math:`(B, C, *)`.
|
|
||||||
Examples:
|
|
||||||
>>> x = torch.rand(1, 4, 3, 3)
|
|
||||||
>>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
|
|
||||||
>>> out.shape
|
|
||||||
torch.Size([1, 4, 3, 3])
|
|
||||||
>>> x = torch.rand(1, 4, 3, 3)
|
|
||||||
>>> mean = torch.zeros(4)
|
|
||||||
>>> std = 255. * torch.ones(4)
|
|
||||||
>>> out = normalize(x, mean, std)
|
|
||||||
>>> out.shape
|
|
||||||
torch.Size([1, 4, 3, 3])
|
|
||||||
"""
|
|
||||||
shape = data.shape
|
|
||||||
if len(mean.shape) == 0 or mean.shape[0] == 1:
|
|
||||||
mean = mean.expand(shape[1])
|
|
||||||
if len(std.shape) == 0 or std.shape[0] == 1:
|
|
||||||
std = std.expand(shape[1])
|
|
||||||
|
|
||||||
# Allow broadcast on channel dimension
|
|
||||||
if mean.shape and mean.shape[0] != 1:
|
|
||||||
if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
|
|
||||||
raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")
|
|
||||||
|
|
||||||
# Allow broadcast on channel dimension
|
|
||||||
if std.shape and std.shape[0] != 1:
|
|
||||||
if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
|
|
||||||
raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")
|
|
||||||
|
|
||||||
mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
|
|
||||||
std = torch.as_tensor(std, device=data.device, dtype=data.dtype)
|
|
||||||
|
|
||||||
if mean.shape:
|
|
||||||
mean = mean[..., :, None]
|
|
||||||
if std.shape:
|
|
||||||
std = std[..., :, None]
|
|
||||||
|
|
||||||
out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std
|
|
||||||
|
|
||||||
return out.view(shape)
|
|
||||||
@ -1,314 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from . import kornia_functions
|
|
||||||
from torch.utils.checkpoint import checkpoint
|
|
||||||
|
|
||||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
|
||||||
|
|
||||||
import open_clip
|
|
||||||
from ldm.util import default, count_params
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractEncoder(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def encode(self, *args, **kwargs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class IdentityEncoder(AbstractEncoder):
|
|
||||||
|
|
||||||
def encode(self, x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ClassEmbedder(nn.Module):
|
|
||||||
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
|
|
||||||
super().__init__()
|
|
||||||
self.key = key
|
|
||||||
self.embedding = nn.Embedding(n_classes, embed_dim)
|
|
||||||
self.n_classes = n_classes
|
|
||||||
self.ucg_rate = ucg_rate
|
|
||||||
|
|
||||||
def forward(self, batch, key=None, disable_dropout=False):
|
|
||||||
if key is None:
|
|
||||||
key = self.key
|
|
||||||
# this is for use in crossattn
|
|
||||||
c = batch[key][:, None]
|
|
||||||
if self.ucg_rate > 0. and not disable_dropout:
|
|
||||||
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
|
||||||
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
|
|
||||||
c = c.long()
|
|
||||||
c = self.embedding(c)
|
|
||||||
return c
|
|
||||||
|
|
||||||
def get_unconditional_conditioning(self, bs, device="cuda"):
|
|
||||||
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
|
|
||||||
uc = torch.ones((bs,), device=device) * uc_class
|
|
||||||
uc = {self.key: uc}
|
|
||||||
return uc
|
|
||||||
|
|
||||||
|
|
||||||
def disabled_train(self, mode=True):
|
|
||||||
"""Overwrite model.train with this function to make sure train/eval mode
|
|
||||||
does not change anymore."""
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class FrozenT5Embedder(AbstractEncoder):
|
|
||||||
"""Uses the T5 transformer encoder for text"""
|
|
||||||
|
|
||||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
|
|
||||||
freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
|
||||||
super().__init__()
|
|
||||||
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
|
||||||
self.transformer = T5EncoderModel.from_pretrained(version)
|
|
||||||
self.device = device
|
|
||||||
self.max_length = max_length # TODO: typical value?
|
|
||||||
if freeze:
|
|
||||||
self.freeze()
|
|
||||||
|
|
||||||
def freeze(self):
|
|
||||||
self.transformer = self.transformer.eval()
|
|
||||||
# self.train = disabled_train
|
|
||||||
for param in self.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
def forward(self, text):
|
|
||||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
|
||||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
|
||||||
tokens = batch_encoding["input_ids"].to(self.device)
|
|
||||||
outputs = self.transformer(input_ids=tokens)
|
|
||||||
|
|
||||||
z = outputs.last_hidden_state
|
|
||||||
return z
|
|
||||||
|
|
||||||
def encode(self, text):
|
|
||||||
return self(text)
|
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
|
||||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
|
||||||
LAYERS = [
|
|
||||||
"last",
|
|
||||||
"pooled",
|
|
||||||
"hidden"
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
|
|
||||||
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
|
|
||||||
super().__init__()
|
|
||||||
assert layer in self.LAYERS
|
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
|
||||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
|
||||||
self.device = device
|
|
||||||
self.max_length = max_length
|
|
||||||
if freeze:
|
|
||||||
self.freeze()
|
|
||||||
self.layer = layer
|
|
||||||
self.layer_idx = layer_idx
|
|
||||||
if layer == "hidden":
|
|
||||||
assert layer_idx is not None
|
|
||||||
assert 0 <= abs(layer_idx) <= 12
|
|
||||||
|
|
||||||
def freeze(self):
|
|
||||||
self.transformer = self.transformer.eval()
|
|
||||||
# self.train = disabled_train
|
|
||||||
for param in self.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
def forward(self, text):
|
|
||||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
|
||||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
|
||||||
tokens = batch_encoding["input_ids"].to(self.device)
|
|
||||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
|
|
||||||
if self.layer == "last":
|
|
||||||
z = outputs.last_hidden_state
|
|
||||||
elif self.layer == "pooled":
|
|
||||||
z = outputs.pooler_output[:, None, :]
|
|
||||||
else:
|
|
||||||
z = outputs.hidden_states[self.layer_idx]
|
|
||||||
return z
|
|
||||||
|
|
||||||
def encode(self, text):
|
|
||||||
return self(text)
|
|
||||||
|
|
||||||
|
|
||||||
class ClipImageEmbedder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
jit=False,
|
|
||||||
device='cuda' if torch.cuda.is_available() else 'cpu',
|
|
||||||
antialias=True,
|
|
||||||
ucg_rate=0.
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
from clip import load as load_clip
|
|
||||||
self.model, _ = load_clip(name=model, device=device, jit=jit)
|
|
||||||
|
|
||||||
self.antialias = antialias
|
|
||||||
|
|
||||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
|
||||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
|
||||||
self.ucg_rate = ucg_rate
|
|
||||||
|
|
||||||
def preprocess(self, x):
|
|
||||||
# normalize to [0,1]
|
|
||||||
# x = kornia_functions.geometry_resize(x, (224, 224),
|
|
||||||
# interpolation='bicubic', align_corners=True,
|
|
||||||
# antialias=self.antialias)
|
|
||||||
x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
|
|
||||||
x = (x + 1.) / 2.
|
|
||||||
# re-normalize according to clip
|
|
||||||
x = kornia_functions.enhance_normalize(x, self.mean, self.std)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(self, x, no_dropout=False):
|
|
||||||
# x is assumed to be in range [-1,1]
|
|
||||||
out = self.model.encode_image(self.preprocess(x))
|
|
||||||
out = out.to(x.dtype)
|
|
||||||
if self.ucg_rate > 0. and not no_dropout:
|
|
||||||
out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
|
||||||
"""
|
|
||||||
Uses the OpenCLIP transformer encoder for text
|
|
||||||
"""
|
|
||||||
LAYERS = [
|
|
||||||
# "pooled",
|
|
||||||
"last",
|
|
||||||
"penultimate"
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
|
||||||
freeze=True, layer="last"):
|
|
||||||
super().__init__()
|
|
||||||
assert layer in self.LAYERS
|
|
||||||
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
|
|
||||||
del model.visual
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
self.device = device
|
|
||||||
self.max_length = max_length
|
|
||||||
if freeze:
|
|
||||||
self.freeze()
|
|
||||||
self.layer = layer
|
|
||||||
if self.layer == "last":
|
|
||||||
self.layer_idx = 0
|
|
||||||
elif self.layer == "penultimate":
|
|
||||||
self.layer_idx = 1
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def freeze(self):
|
|
||||||
self.model = self.model.eval()
|
|
||||||
for param in self.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
def forward(self, text):
|
|
||||||
tokens = open_clip.tokenize(text)
|
|
||||||
z = self.encode_with_transformer(tokens.to(self.device))
|
|
||||||
return z
|
|
||||||
|
|
||||||
def encode_with_transformer(self, text):
|
|
||||||
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
|
||||||
x = x + self.model.positional_embedding
|
|
||||||
x = x.permute(1, 0, 2) # NLD -> LND
|
|
||||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
|
||||||
x = x.permute(1, 0, 2) # LND -> NLD
|
|
||||||
x = self.model.ln_final(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
|
||||||
for i, r in enumerate(self.model.transformer.resblocks):
|
|
||||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
|
||||||
break
|
|
||||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
|
|
||||||
x = checkpoint(r, x, attn_mask)
|
|
||||||
else:
|
|
||||||
x = r(x, attn_mask=attn_mask)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def encode(self, text):
|
|
||||||
return self(text)
|
|
||||||
|
|
||||||
|
|
||||||
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
|
|
||||||
"""
|
|
||||||
Uses the OpenCLIP vision transformer encoder for images
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
|
||||||
freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
|
|
||||||
super().__init__()
|
|
||||||
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
|
|
||||||
pretrained=version, )
|
|
||||||
del model.transformer
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
self.device = device
|
|
||||||
self.max_length = max_length
|
|
||||||
if freeze:
|
|
||||||
self.freeze()
|
|
||||||
self.layer = layer
|
|
||||||
if self.layer == "penultimate":
|
|
||||||
raise NotImplementedError()
|
|
||||||
self.layer_idx = 1
|
|
||||||
|
|
||||||
self.antialias = antialias
|
|
||||||
|
|
||||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
|
||||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
|
||||||
self.ucg_rate = ucg_rate
|
|
||||||
|
|
||||||
def preprocess(self, x):
|
|
||||||
# normalize to [0,1]
|
|
||||||
# x = kornia.geometry.resize(x, (224, 224),
|
|
||||||
# interpolation='bicubic', align_corners=True,
|
|
||||||
# antialias=self.antialias)
|
|
||||||
x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
|
|
||||||
x = (x + 1.) / 2.
|
|
||||||
# renormalize according to clip
|
|
||||||
x = kornia_functions.enhance_normalize(x, self.mean, self.std)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def freeze(self):
|
|
||||||
self.model = self.model.eval()
|
|
||||||
for param in self.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
def forward(self, image, no_dropout=False):
|
|
||||||
z = self.encode_with_vision_transformer(image)
|
|
||||||
if self.ucg_rate > 0. and not no_dropout:
|
|
||||||
z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
|
|
||||||
return z
|
|
||||||
|
|
||||||
def encode_with_vision_transformer(self, img):
|
|
||||||
img = self.preprocess(img)
|
|
||||||
x = self.model.visual(img)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def encode(self, text):
|
|
||||||
return self(text)
|
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPT5Encoder(AbstractEncoder):
|
|
||||||
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
|
|
||||||
clip_max_length=77, t5_max_length=77):
|
|
||||||
super().__init__()
|
|
||||||
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
|
|
||||||
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
|
|
||||||
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
|
|
||||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
|
|
||||||
|
|
||||||
def encode(self, text):
|
|
||||||
return self(text)
|
|
||||||
|
|
||||||
def forward(self, text):
|
|
||||||
clip_z = self.clip_encoder.encode(text)
|
|
||||||
t5_z = self.t5_encoder.encode(text)
|
|
||||||
return [clip_z, t5_z]
|
|
||||||
@ -1,2 +0,0 @@
|
|||||||
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
|
|
||||||
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
|
|
||||||
@ -1,730 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# Super-Resolution
|
|
||||||
# --------------------------------------------
|
|
||||||
#
|
|
||||||
# Kai Zhang (cskaizhang@gmail.com)
|
|
||||||
# https://github.com/cszn
|
|
||||||
# From 2019/03--2021/08
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
import random
|
|
||||||
from scipy import ndimage
|
|
||||||
import scipy
|
|
||||||
import scipy.stats as ss
|
|
||||||
from scipy.interpolate import interp2d
|
|
||||||
from scipy.linalg import orth
|
|
||||||
import albumentations
|
|
||||||
|
|
||||||
import ldm.modules.image_degradation.utils_image as util
|
|
||||||
|
|
||||||
|
|
||||||
def modcrop_np(img, sf):
|
|
||||||
'''
|
|
||||||
Args:
|
|
||||||
img: numpy image, WxH or WxHxC
|
|
||||||
sf: scale factor
|
|
||||||
Return:
|
|
||||||
cropped image
|
|
||||||
'''
|
|
||||||
w, h = img.shape[:2]
|
|
||||||
im = np.copy(img)
|
|
||||||
return im[:w - w % sf, :h - h % sf, ...]
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# anisotropic Gaussian kernels
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def analytic_kernel(k):
|
|
||||||
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
|
||||||
k_size = k.shape[0]
|
|
||||||
# Calculate the big kernels size
|
|
||||||
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
|
||||||
# Loop over the small kernel to fill the big one
|
|
||||||
for r in range(k_size):
|
|
||||||
for c in range(k_size):
|
|
||||||
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
|
|
||||||
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
|
|
||||||
crop = k_size // 2
|
|
||||||
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
|
||||||
# Normalize to 1
|
|
||||||
return cropped_big_k / cropped_big_k.sum()
|
|
||||||
|
|
||||||
|
|
||||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
|
||||||
""" generate an anisotropic Gaussian kernel
|
|
||||||
Args:
|
|
||||||
ksize : e.g., 15, kernel size
|
|
||||||
theta : [0, pi], rotation angle range
|
|
||||||
l1 : [0.1,50], scaling of eigenvalues
|
|
||||||
l2 : [0.1,l1], scaling of eigenvalues
|
|
||||||
If l1 = l2, will get an isotropic Gaussian kernel.
|
|
||||||
Returns:
|
|
||||||
k : kernel
|
|
||||||
"""
|
|
||||||
|
|
||||||
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
|
|
||||||
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
|
||||||
D = np.array([[l1, 0], [0, l2]])
|
|
||||||
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
|
||||||
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
|
||||||
|
|
||||||
return k
|
|
||||||
|
|
||||||
|
|
||||||
def gm_blur_kernel(mean, cov, size=15):
|
|
||||||
center = size / 2.0 + 0.5
|
|
||||||
k = np.zeros([size, size])
|
|
||||||
for y in range(size):
|
|
||||||
for x in range(size):
|
|
||||||
cy = y - center + 1
|
|
||||||
cx = x - center + 1
|
|
||||||
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
|
||||||
|
|
||||||
k = k / np.sum(k)
|
|
||||||
return k
|
|
||||||
|
|
||||||
|
|
||||||
def shift_pixel(x, sf, upper_left=True):
|
|
||||||
"""shift pixel for super-resolution with different scale factors
|
|
||||||
Args:
|
|
||||||
x: WxHxC or WxH
|
|
||||||
sf: scale factor
|
|
||||||
upper_left: shift direction
|
|
||||||
"""
|
|
||||||
h, w = x.shape[:2]
|
|
||||||
shift = (sf - 1) * 0.5
|
|
||||||
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
|
||||||
if upper_left:
|
|
||||||
x1 = xv + shift
|
|
||||||
y1 = yv + shift
|
|
||||||
else:
|
|
||||||
x1 = xv - shift
|
|
||||||
y1 = yv - shift
|
|
||||||
|
|
||||||
x1 = np.clip(x1, 0, w - 1)
|
|
||||||
y1 = np.clip(y1, 0, h - 1)
|
|
||||||
|
|
||||||
if x.ndim == 2:
|
|
||||||
x = interp2d(xv, yv, x)(x1, y1)
|
|
||||||
if x.ndim == 3:
|
|
||||||
for i in range(x.shape[-1]):
|
|
||||||
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def blur(x, k):
|
|
||||||
'''
|
|
||||||
x: image, NxcxHxW
|
|
||||||
k: kernel, Nx1xhxw
|
|
||||||
'''
|
|
||||||
n, c = x.shape[:2]
|
|
||||||
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
|
|
||||||
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
|
|
||||||
k = k.repeat(1, c, 1, 1)
|
|
||||||
k = k.view(-1, 1, k.shape[2], k.shape[3])
|
|
||||||
x = x.view(1, -1, x.shape[2], x.shape[3])
|
|
||||||
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
|
|
||||||
x = x.view(n, c, x.shape[2], x.shape[3])
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
|
|
||||||
""""
|
|
||||||
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
|
||||||
# Kai Zhang
|
|
||||||
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
|
||||||
# max_var = 2.5 * sf
|
|
||||||
"""
|
|
||||||
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
|
||||||
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
|
||||||
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
|
||||||
theta = np.random.rand() * np.pi # random theta
|
|
||||||
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
|
||||||
|
|
||||||
# Set COV matrix using Lambdas and Theta
|
|
||||||
LAMBDA = np.diag([lambda_1, lambda_2])
|
|
||||||
Q = np.array([[np.cos(theta), -np.sin(theta)],
|
|
||||||
[np.sin(theta), np.cos(theta)]])
|
|
||||||
SIGMA = Q @ LAMBDA @ Q.T
|
|
||||||
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
|
||||||
|
|
||||||
# Set expectation position (shifting kernel for aligned image)
|
|
||||||
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
|
||||||
MU = MU[None, None, :, None]
|
|
||||||
|
|
||||||
# Create meshgrid for Gaussian
|
|
||||||
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
|
||||||
Z = np.stack([X, Y], 2)[:, :, :, None]
|
|
||||||
|
|
||||||
# Calcualte Gaussian for every pixel of the kernel
|
|
||||||
ZZ = Z - MU
|
|
||||||
ZZ_t = ZZ.transpose(0, 1, 3, 2)
|
|
||||||
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
|
||||||
|
|
||||||
# shift the kernel so it will be centered
|
|
||||||
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
|
||||||
|
|
||||||
# Normalize the kernel and return
|
|
||||||
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
|
||||||
kernel = raw_kernel / np.sum(raw_kernel)
|
|
||||||
return kernel
|
|
||||||
|
|
||||||
|
|
||||||
def fspecial_gaussian(hsize, sigma):
|
|
||||||
hsize = [hsize, hsize]
|
|
||||||
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
|
|
||||||
std = sigma
|
|
||||||
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
|
|
||||||
arg = -(x * x + y * y) / (2 * std * std)
|
|
||||||
h = np.exp(arg)
|
|
||||||
h[h < scipy.finfo(float).eps * h.max()] = 0
|
|
||||||
sumh = h.sum()
|
|
||||||
if sumh != 0:
|
|
||||||
h = h / sumh
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
def fspecial_laplacian(alpha):
|
|
||||||
alpha = max([0, min([alpha, 1])])
|
|
||||||
h1 = alpha / (alpha + 1)
|
|
||||||
h2 = (1 - alpha) / (alpha + 1)
|
|
||||||
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
|
|
||||||
h = np.array(h)
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
def fspecial(filter_type, *args, **kwargs):
|
|
||||||
'''
|
|
||||||
python code from:
|
|
||||||
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
|
||||||
'''
|
|
||||||
if filter_type == 'gaussian':
|
|
||||||
return fspecial_gaussian(*args, **kwargs)
|
|
||||||
if filter_type == 'laplacian':
|
|
||||||
return fspecial_laplacian(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# degradation models
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def bicubic_degradation(x, sf=3):
|
|
||||||
'''
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
bicubicly downsampled LR image
|
|
||||||
'''
|
|
||||||
x = util.imresize_np(x, scale=1 / sf)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def srmd_degradation(x, k, sf=3):
|
|
||||||
''' blur + bicubic downsampling
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]
|
|
||||||
k: hxw, double
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
downsampled LR image
|
|
||||||
Reference:
|
|
||||||
@inproceedings{zhang2018learning,
|
|
||||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
|
||||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
|
||||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
|
||||||
pages={3262--3271},
|
|
||||||
year={2018}
|
|
||||||
}
|
|
||||||
'''
|
|
||||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
|
|
||||||
x = bicubic_degradation(x, sf=sf)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def dpsr_degradation(x, k, sf=3):
|
|
||||||
''' bicubic downsampling + blur
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]
|
|
||||||
k: hxw, double
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
downsampled LR image
|
|
||||||
Reference:
|
|
||||||
@inproceedings{zhang2019deep,
|
|
||||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
|
||||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
|
||||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
|
||||||
pages={1671--1681},
|
|
||||||
year={2019}
|
|
||||||
}
|
|
||||||
'''
|
|
||||||
x = bicubic_degradation(x, sf=sf)
|
|
||||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def classical_degradation(x, k, sf=3):
|
|
||||||
''' blur + downsampling
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]/[0, 255]
|
|
||||||
k: hxw, double
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
downsampled LR image
|
|
||||||
'''
|
|
||||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
|
||||||
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
|
||||||
st = 0
|
|
||||||
return x[st::sf, st::sf, ...]
|
|
||||||
|
|
||||||
|
|
||||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
|
|
||||||
"""USM sharpening. borrowed from real-ESRGAN
|
|
||||||
Input image: I; Blurry image: B.
|
|
||||||
1. K = I + weight * (I - B)
|
|
||||||
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
|
||||||
3. Blur mask:
|
|
||||||
4. Out = Mask * K + (1 - Mask) * I
|
|
||||||
Args:
|
|
||||||
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
|
||||||
weight (float): Sharp weight. Default: 1.
|
|
||||||
radius (float): Kernel size of Gaussian blur. Default: 50.
|
|
||||||
threshold (int):
|
|
||||||
"""
|
|
||||||
if radius % 2 == 0:
|
|
||||||
radius += 1
|
|
||||||
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
|
||||||
residual = img - blur
|
|
||||||
mask = np.abs(residual) * 255 > threshold
|
|
||||||
mask = mask.astype('float32')
|
|
||||||
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
|
||||||
|
|
||||||
K = img + weight * residual
|
|
||||||
K = np.clip(K, 0, 1)
|
|
||||||
return soft_mask * K + (1 - soft_mask) * img
|
|
||||||
|
|
||||||
|
|
||||||
def add_blur(img, sf=4):
|
|
||||||
wd2 = 4.0 + sf
|
|
||||||
wd = 2.0 + 0.2 * sf
|
|
||||||
if random.random() < 0.5:
|
|
||||||
l1 = wd2 * random.random()
|
|
||||||
l2 = wd2 * random.random()
|
|
||||||
k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
|
|
||||||
else:
|
|
||||||
k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
|
|
||||||
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_resize(img, sf=4):
|
|
||||||
rnum = np.random.rand()
|
|
||||||
if rnum > 0.8: # up
|
|
||||||
sf1 = random.uniform(1, 2)
|
|
||||||
elif rnum < 0.7: # down
|
|
||||||
sf1 = random.uniform(0.5 / sf, 1)
|
|
||||||
else:
|
|
||||||
sf1 = 1.0
|
|
||||||
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
|
||||||
# noise_level = random.randint(noise_level1, noise_level2)
|
|
||||||
# rnum = np.random.rand()
|
|
||||||
# if rnum > 0.6: # add color Gaussian noise
|
|
||||||
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
|
||||||
# elif rnum < 0.4: # add grayscale Gaussian noise
|
|
||||||
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
|
||||||
# else: # add noise
|
|
||||||
# L = noise_level2 / 255.
|
|
||||||
# D = np.diag(np.random.rand(3))
|
|
||||||
# U = orth(np.random.rand(3, 3))
|
|
||||||
# conv = np.dot(np.dot(np.transpose(U), D), U)
|
|
||||||
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
|
||||||
# img = np.clip(img, 0.0, 1.0)
|
|
||||||
# return img
|
|
||||||
|
|
||||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
|
||||||
noise_level = random.randint(noise_level1, noise_level2)
|
|
||||||
rnum = np.random.rand()
|
|
||||||
if rnum > 0.6: # add color Gaussian noise
|
|
||||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
|
||||||
elif rnum < 0.4: # add grayscale Gaussian noise
|
|
||||||
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
|
||||||
else: # add noise
|
|
||||||
L = noise_level2 / 255.
|
|
||||||
D = np.diag(np.random.rand(3))
|
|
||||||
U = orth(np.random.rand(3, 3))
|
|
||||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
|
||||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
|
||||||
noise_level = random.randint(noise_level1, noise_level2)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
rnum = random.random()
|
|
||||||
if rnum > 0.6:
|
|
||||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
|
||||||
elif rnum < 0.4:
|
|
||||||
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
|
||||||
else:
|
|
||||||
L = noise_level2 / 255.
|
|
||||||
D = np.diag(np.random.rand(3))
|
|
||||||
U = orth(np.random.rand(3, 3))
|
|
||||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
|
||||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_Poisson_noise(img):
|
|
||||||
img = np.clip((img * 255.0).round(), 0, 255) / 255.
|
|
||||||
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
|
|
||||||
if random.random() < 0.5:
|
|
||||||
img = np.random.poisson(img * vals).astype(np.float32) / vals
|
|
||||||
else:
|
|
||||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
|
||||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
|
|
||||||
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
|
||||||
img += noise_gray[:, :, np.newaxis]
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_JPEG_noise(img):
|
|
||||||
quality_factor = random.randint(30, 95)
|
|
||||||
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
|
||||||
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
|
||||||
img = cv2.imdecode(encimg, 1)
|
|
||||||
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
|
|
||||||
h, w = lq.shape[:2]
|
|
||||||
rnd_h = random.randint(0, h - lq_patchsize)
|
|
||||||
rnd_w = random.randint(0, w - lq_patchsize)
|
|
||||||
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
|
|
||||||
|
|
||||||
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
|
|
||||||
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
|
|
||||||
return lq, hq
|
|
||||||
|
|
||||||
|
|
||||||
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
|
||||||
"""
|
|
||||||
This is the degradation model of BSRGAN from the paper
|
|
||||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
|
||||||
----------
|
|
||||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
|
||||||
sf: scale factor
|
|
||||||
isp_model: camera ISP model
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
|
||||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
|
||||||
"""
|
|
||||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
|
||||||
sf_ori = sf
|
|
||||||
|
|
||||||
h1, w1 = img.shape[:2]
|
|
||||||
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
|
|
||||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
|
||||||
raise ValueError(f'img size ({h1}X{w1}) is too small!')
|
|
||||||
|
|
||||||
hq = img.copy()
|
|
||||||
|
|
||||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
|
||||||
if np.random.rand() < 0.5:
|
|
||||||
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]))
|
|
||||||
else:
|
|
||||||
img = util.imresize_np(img, 1 / 2, True)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
sf = 2
|
|
||||||
|
|
||||||
shuffle_order = random.sample(range(7), 7)
|
|
||||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
|
||||||
if idx1 > idx2: # keep downsample3 last
|
|
||||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
|
||||||
|
|
||||||
for i in shuffle_order:
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
img = add_blur(img, sf=sf)
|
|
||||||
|
|
||||||
elif i == 1:
|
|
||||||
img = add_blur(img, sf=sf)
|
|
||||||
|
|
||||||
elif i == 2:
|
|
||||||
a, b = img.shape[1], img.shape[0]
|
|
||||||
# downsample2
|
|
||||||
if random.random() < 0.75:
|
|
||||||
sf1 = random.uniform(1, 2 * sf)
|
|
||||||
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]))
|
|
||||||
else:
|
|
||||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
|
||||||
k_shifted = shift_pixel(k, sf)
|
|
||||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
|
||||||
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
|
||||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 3:
|
|
||||||
# downsample3
|
|
||||||
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 4:
|
|
||||||
# add Gaussian noise
|
|
||||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
|
||||||
|
|
||||||
elif i == 5:
|
|
||||||
# add JPEG noise
|
|
||||||
if random.random() < jpeg_prob:
|
|
||||||
img = add_JPEG_noise(img)
|
|
||||||
|
|
||||||
elif i == 6:
|
|
||||||
# add processed camera sensor noise
|
|
||||||
if random.random() < isp_prob and isp_model is not None:
|
|
||||||
with torch.no_grad():
|
|
||||||
img, hq = isp_model.forward(img.copy(), hq)
|
|
||||||
|
|
||||||
# add final JPEG compression noise
|
|
||||||
img = add_JPEG_noise(img)
|
|
||||||
|
|
||||||
# random crop
|
|
||||||
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
|
|
||||||
|
|
||||||
return img, hq
|
|
||||||
|
|
||||||
|
|
||||||
# todo no isp_model?
|
|
||||||
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
|
||||||
"""
|
|
||||||
This is the degradation model of BSRGAN from the paper
|
|
||||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
|
||||||
----------
|
|
||||||
sf: scale factor
|
|
||||||
isp_model: camera ISP model
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
|
||||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
|
||||||
"""
|
|
||||||
image = util.uint2single(image)
|
|
||||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
|
||||||
sf_ori = sf
|
|
||||||
|
|
||||||
h1, w1 = image.shape[:2]
|
|
||||||
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
|
||||||
h, w = image.shape[:2]
|
|
||||||
|
|
||||||
hq = image.copy()
|
|
||||||
|
|
||||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
|
||||||
if np.random.rand() < 0.5:
|
|
||||||
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]))
|
|
||||||
else:
|
|
||||||
image = util.imresize_np(image, 1 / 2, True)
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
|
||||||
sf = 2
|
|
||||||
|
|
||||||
shuffle_order = random.sample(range(7), 7)
|
|
||||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
|
||||||
if idx1 > idx2: # keep downsample3 last
|
|
||||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
|
||||||
|
|
||||||
for i in shuffle_order:
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
image = add_blur(image, sf=sf)
|
|
||||||
|
|
||||||
elif i == 1:
|
|
||||||
image = add_blur(image, sf=sf)
|
|
||||||
|
|
||||||
elif i == 2:
|
|
||||||
a, b = image.shape[1], image.shape[0]
|
|
||||||
# downsample2
|
|
||||||
if random.random() < 0.75:
|
|
||||||
sf1 = random.uniform(1, 2 * sf)
|
|
||||||
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]))
|
|
||||||
else:
|
|
||||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
|
||||||
k_shifted = shift_pixel(k, sf)
|
|
||||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
|
||||||
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
|
||||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 3:
|
|
||||||
# downsample3
|
|
||||||
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 4:
|
|
||||||
# add Gaussian noise
|
|
||||||
image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
|
|
||||||
|
|
||||||
elif i == 5:
|
|
||||||
# add JPEG noise
|
|
||||||
if random.random() < jpeg_prob:
|
|
||||||
image = add_JPEG_noise(image)
|
|
||||||
|
|
||||||
# elif i == 6:
|
|
||||||
# # add processed camera sensor noise
|
|
||||||
# if random.random() < isp_prob and isp_model is not None:
|
|
||||||
# with torch.no_grad():
|
|
||||||
# img, hq = isp_model.forward(img.copy(), hq)
|
|
||||||
|
|
||||||
# add final JPEG compression noise
|
|
||||||
image = add_JPEG_noise(image)
|
|
||||||
image = util.single2uint(image)
|
|
||||||
example = {"image":image}
|
|
||||||
return example
|
|
||||||
|
|
||||||
|
|
||||||
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
|
|
||||||
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
|
|
||||||
"""
|
|
||||||
This is an extended degradation model by combining
|
|
||||||
the degradation models of BSRGAN and Real-ESRGAN
|
|
||||||
----------
|
|
||||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
|
||||||
sf: scale factor
|
|
||||||
use_shuffle: the degradation shuffle
|
|
||||||
use_sharp: sharpening the img
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
|
||||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
|
||||||
"""
|
|
||||||
|
|
||||||
h1, w1 = img.shape[:2]
|
|
||||||
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
|
|
||||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
|
||||||
raise ValueError(f'img size ({h1}X{w1}) is too small!')
|
|
||||||
|
|
||||||
if use_sharp:
|
|
||||||
img = add_sharpening(img)
|
|
||||||
hq = img.copy()
|
|
||||||
|
|
||||||
if random.random() < shuffle_prob:
|
|
||||||
shuffle_order = random.sample(range(13), 13)
|
|
||||||
else:
|
|
||||||
shuffle_order = list(range(13))
|
|
||||||
# local shuffle for noise, JPEG is always the last one
|
|
||||||
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
|
|
||||||
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
|
|
||||||
|
|
||||||
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
|
|
||||||
|
|
||||||
for i in shuffle_order:
|
|
||||||
if i == 0:
|
|
||||||
img = add_blur(img, sf=sf)
|
|
||||||
elif i == 1:
|
|
||||||
img = add_resize(img, sf=sf)
|
|
||||||
elif i == 2:
|
|
||||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
|
||||||
elif i == 3:
|
|
||||||
if random.random() < poisson_prob:
|
|
||||||
img = add_Poisson_noise(img)
|
|
||||||
elif i == 4:
|
|
||||||
if random.random() < speckle_prob:
|
|
||||||
img = add_speckle_noise(img)
|
|
||||||
elif i == 5:
|
|
||||||
if random.random() < isp_prob and isp_model is not None:
|
|
||||||
with torch.no_grad():
|
|
||||||
img, hq = isp_model.forward(img.copy(), hq)
|
|
||||||
elif i == 6:
|
|
||||||
img = add_JPEG_noise(img)
|
|
||||||
elif i == 7:
|
|
||||||
img = add_blur(img, sf=sf)
|
|
||||||
elif i == 8:
|
|
||||||
img = add_resize(img, sf=sf)
|
|
||||||
elif i == 9:
|
|
||||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
|
||||||
elif i == 10:
|
|
||||||
if random.random() < poisson_prob:
|
|
||||||
img = add_Poisson_noise(img)
|
|
||||||
elif i == 11:
|
|
||||||
if random.random() < speckle_prob:
|
|
||||||
img = add_speckle_noise(img)
|
|
||||||
elif i == 12:
|
|
||||||
if random.random() < isp_prob and isp_model is not None:
|
|
||||||
with torch.no_grad():
|
|
||||||
img, hq = isp_model.forward(img.copy(), hq)
|
|
||||||
else:
|
|
||||||
print('check the shuffle!')
|
|
||||||
|
|
||||||
# resize to desired size
|
|
||||||
img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]))
|
|
||||||
|
|
||||||
# add final JPEG compression noise
|
|
||||||
img = add_JPEG_noise(img)
|
|
||||||
|
|
||||||
# random crop
|
|
||||||
img, hq = random_crop(img, hq, sf, lq_patchsize)
|
|
||||||
|
|
||||||
return img, hq
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
print("hey")
|
|
||||||
img = util.imread_uint('utils/test.png', 3)
|
|
||||||
print(img)
|
|
||||||
img = util.uint2single(img)
|
|
||||||
print(img)
|
|
||||||
img = img[:448, :448]
|
|
||||||
h = img.shape[0] // 4
|
|
||||||
print("resizing to", h)
|
|
||||||
sf = 4
|
|
||||||
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
|
|
||||||
for i in range(20):
|
|
||||||
print(i)
|
|
||||||
img_lq = deg_fn(img)
|
|
||||||
print(img_lq)
|
|
||||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
|
|
||||||
print(img_lq.shape)
|
|
||||||
print("bicubic", img_lq_bicubic.shape)
|
|
||||||
print(img_hq.shape)
|
|
||||||
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
|
||||||
interpolation=0)
|
|
||||||
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
|
||||||
interpolation=0)
|
|
||||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
|
||||||
util.imsave(img_concat, str(i) + '.png')
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,651 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
import random
|
|
||||||
from scipy import ndimage
|
|
||||||
import scipy
|
|
||||||
import scipy.stats as ss
|
|
||||||
from scipy.interpolate import interp2d
|
|
||||||
from scipy.linalg import orth
|
|
||||||
import albumentations
|
|
||||||
|
|
||||||
import ldm.modules.image_degradation.utils_image as util
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# Super-Resolution
|
|
||||||
# --------------------------------------------
|
|
||||||
#
|
|
||||||
# Kai Zhang (cskaizhang@gmail.com)
|
|
||||||
# https://github.com/cszn
|
|
||||||
# From 2019/03--2021/08
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
def modcrop_np(img, sf):
|
|
||||||
'''
|
|
||||||
Args:
|
|
||||||
img: numpy image, WxH or WxHxC
|
|
||||||
sf: scale factor
|
|
||||||
Return:
|
|
||||||
cropped image
|
|
||||||
'''
|
|
||||||
w, h = img.shape[:2]
|
|
||||||
im = np.copy(img)
|
|
||||||
return im[:w - w % sf, :h - h % sf, ...]
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# anisotropic Gaussian kernels
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def analytic_kernel(k):
|
|
||||||
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
|
||||||
k_size = k.shape[0]
|
|
||||||
# Calculate the big kernels size
|
|
||||||
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
|
||||||
# Loop over the small kernel to fill the big one
|
|
||||||
for r in range(k_size):
|
|
||||||
for c in range(k_size):
|
|
||||||
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
|
|
||||||
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
|
|
||||||
crop = k_size // 2
|
|
||||||
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
|
||||||
# Normalize to 1
|
|
||||||
return cropped_big_k / cropped_big_k.sum()
|
|
||||||
|
|
||||||
|
|
||||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
|
||||||
""" generate an anisotropic Gaussian kernel
|
|
||||||
Args:
|
|
||||||
ksize : e.g., 15, kernel size
|
|
||||||
theta : [0, pi], rotation angle range
|
|
||||||
l1 : [0.1,50], scaling of eigenvalues
|
|
||||||
l2 : [0.1,l1], scaling of eigenvalues
|
|
||||||
If l1 = l2, will get an isotropic Gaussian kernel.
|
|
||||||
Returns:
|
|
||||||
k : kernel
|
|
||||||
"""
|
|
||||||
|
|
||||||
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
|
|
||||||
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
|
||||||
D = np.array([[l1, 0], [0, l2]])
|
|
||||||
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
|
||||||
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
|
||||||
|
|
||||||
return k
|
|
||||||
|
|
||||||
|
|
||||||
def gm_blur_kernel(mean, cov, size=15):
|
|
||||||
center = size / 2.0 + 0.5
|
|
||||||
k = np.zeros([size, size])
|
|
||||||
for y in range(size):
|
|
||||||
for x in range(size):
|
|
||||||
cy = y - center + 1
|
|
||||||
cx = x - center + 1
|
|
||||||
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
|
||||||
|
|
||||||
k = k / np.sum(k)
|
|
||||||
return k
|
|
||||||
|
|
||||||
|
|
||||||
def shift_pixel(x, sf, upper_left=True):
|
|
||||||
"""shift pixel for super-resolution with different scale factors
|
|
||||||
Args:
|
|
||||||
x: WxHxC or WxH
|
|
||||||
sf: scale factor
|
|
||||||
upper_left: shift direction
|
|
||||||
"""
|
|
||||||
h, w = x.shape[:2]
|
|
||||||
shift = (sf - 1) * 0.5
|
|
||||||
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
|
||||||
if upper_left:
|
|
||||||
x1 = xv + shift
|
|
||||||
y1 = yv + shift
|
|
||||||
else:
|
|
||||||
x1 = xv - shift
|
|
||||||
y1 = yv - shift
|
|
||||||
|
|
||||||
x1 = np.clip(x1, 0, w - 1)
|
|
||||||
y1 = np.clip(y1, 0, h - 1)
|
|
||||||
|
|
||||||
if x.ndim == 2:
|
|
||||||
x = interp2d(xv, yv, x)(x1, y1)
|
|
||||||
if x.ndim == 3:
|
|
||||||
for i in range(x.shape[-1]):
|
|
||||||
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def blur(x, k):
|
|
||||||
'''
|
|
||||||
x: image, NxcxHxW
|
|
||||||
k: kernel, Nx1xhxw
|
|
||||||
'''
|
|
||||||
n, c = x.shape[:2]
|
|
||||||
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
|
|
||||||
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
|
|
||||||
k = k.repeat(1, c, 1, 1)
|
|
||||||
k = k.view(-1, 1, k.shape[2], k.shape[3])
|
|
||||||
x = x.view(1, -1, x.shape[2], x.shape[3])
|
|
||||||
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
|
|
||||||
x = x.view(n, c, x.shape[2], x.shape[3])
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
|
|
||||||
""""
|
|
||||||
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
|
||||||
# Kai Zhang
|
|
||||||
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
|
||||||
# max_var = 2.5 * sf
|
|
||||||
"""
|
|
||||||
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
|
||||||
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
|
||||||
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
|
||||||
theta = np.random.rand() * np.pi # random theta
|
|
||||||
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
|
||||||
|
|
||||||
# Set COV matrix using Lambdas and Theta
|
|
||||||
LAMBDA = np.diag([lambda_1, lambda_2])
|
|
||||||
Q = np.array([[np.cos(theta), -np.sin(theta)],
|
|
||||||
[np.sin(theta), np.cos(theta)]])
|
|
||||||
SIGMA = Q @ LAMBDA @ Q.T
|
|
||||||
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
|
||||||
|
|
||||||
# Set expectation position (shifting kernel for aligned image)
|
|
||||||
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
|
||||||
MU = MU[None, None, :, None]
|
|
||||||
|
|
||||||
# Create meshgrid for Gaussian
|
|
||||||
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
|
||||||
Z = np.stack([X, Y], 2)[:, :, :, None]
|
|
||||||
|
|
||||||
# Calcualte Gaussian for every pixel of the kernel
|
|
||||||
ZZ = Z - MU
|
|
||||||
ZZ_t = ZZ.transpose(0, 1, 3, 2)
|
|
||||||
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
|
||||||
|
|
||||||
# shift the kernel so it will be centered
|
|
||||||
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
|
||||||
|
|
||||||
# Normalize the kernel and return
|
|
||||||
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
|
||||||
kernel = raw_kernel / np.sum(raw_kernel)
|
|
||||||
return kernel
|
|
||||||
|
|
||||||
|
|
||||||
def fspecial_gaussian(hsize, sigma):
|
|
||||||
hsize = [hsize, hsize]
|
|
||||||
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
|
|
||||||
std = sigma
|
|
||||||
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
|
|
||||||
arg = -(x * x + y * y) / (2 * std * std)
|
|
||||||
h = np.exp(arg)
|
|
||||||
h[h < scipy.finfo(float).eps * h.max()] = 0
|
|
||||||
sumh = h.sum()
|
|
||||||
if sumh != 0:
|
|
||||||
h = h / sumh
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
def fspecial_laplacian(alpha):
|
|
||||||
alpha = max([0, min([alpha, 1])])
|
|
||||||
h1 = alpha / (alpha + 1)
|
|
||||||
h2 = (1 - alpha) / (alpha + 1)
|
|
||||||
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
|
|
||||||
h = np.array(h)
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
def fspecial(filter_type, *args, **kwargs):
|
|
||||||
'''
|
|
||||||
python code from:
|
|
||||||
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
|
||||||
'''
|
|
||||||
if filter_type == 'gaussian':
|
|
||||||
return fspecial_gaussian(*args, **kwargs)
|
|
||||||
if filter_type == 'laplacian':
|
|
||||||
return fspecial_laplacian(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# degradation models
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def bicubic_degradation(x, sf=3):
|
|
||||||
'''
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
bicubicly downsampled LR image
|
|
||||||
'''
|
|
||||||
x = util.imresize_np(x, scale=1 / sf)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def srmd_degradation(x, k, sf=3):
|
|
||||||
''' blur + bicubic downsampling
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]
|
|
||||||
k: hxw, double
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
downsampled LR image
|
|
||||||
Reference:
|
|
||||||
@inproceedings{zhang2018learning,
|
|
||||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
|
||||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
|
||||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
|
||||||
pages={3262--3271},
|
|
||||||
year={2018}
|
|
||||||
}
|
|
||||||
'''
|
|
||||||
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
|
|
||||||
x = bicubic_degradation(x, sf=sf)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def dpsr_degradation(x, k, sf=3):
|
|
||||||
''' bicubic downsampling + blur
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]
|
|
||||||
k: hxw, double
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
downsampled LR image
|
|
||||||
Reference:
|
|
||||||
@inproceedings{zhang2019deep,
|
|
||||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
|
||||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
|
||||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
|
||||||
pages={1671--1681},
|
|
||||||
year={2019}
|
|
||||||
}
|
|
||||||
'''
|
|
||||||
x = bicubic_degradation(x, sf=sf)
|
|
||||||
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def classical_degradation(x, k, sf=3):
|
|
||||||
''' blur + downsampling
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]/[0, 255]
|
|
||||||
k: hxw, double
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
downsampled LR image
|
|
||||||
'''
|
|
||||||
x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
|
||||||
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
|
||||||
st = 0
|
|
||||||
return x[st::sf, st::sf, ...]
|
|
||||||
|
|
||||||
|
|
||||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
|
|
||||||
"""USM sharpening. borrowed from real-ESRGAN
|
|
||||||
Input image: I; Blurry image: B.
|
|
||||||
1. K = I + weight * (I - B)
|
|
||||||
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
|
||||||
3. Blur mask:
|
|
||||||
4. Out = Mask * K + (1 - Mask) * I
|
|
||||||
Args:
|
|
||||||
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
|
||||||
weight (float): Sharp weight. Default: 1.
|
|
||||||
radius (float): Kernel size of Gaussian blur. Default: 50.
|
|
||||||
threshold (int):
|
|
||||||
"""
|
|
||||||
if radius % 2 == 0:
|
|
||||||
radius += 1
|
|
||||||
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
|
||||||
residual = img - blur
|
|
||||||
mask = np.abs(residual) * 255 > threshold
|
|
||||||
mask = mask.astype('float32')
|
|
||||||
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
|
||||||
|
|
||||||
K = img + weight * residual
|
|
||||||
K = np.clip(K, 0, 1)
|
|
||||||
return soft_mask * K + (1 - soft_mask) * img
|
|
||||||
|
|
||||||
|
|
||||||
def add_blur(img, sf=4):
|
|
||||||
wd2 = 4.0 + sf
|
|
||||||
wd = 2.0 + 0.2 * sf
|
|
||||||
|
|
||||||
wd2 = wd2/4
|
|
||||||
wd = wd/4
|
|
||||||
|
|
||||||
if random.random() < 0.5:
|
|
||||||
l1 = wd2 * random.random()
|
|
||||||
l2 = wd2 * random.random()
|
|
||||||
k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
|
|
||||||
else:
|
|
||||||
k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
|
|
||||||
img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_resize(img, sf=4):
|
|
||||||
rnum = np.random.rand()
|
|
||||||
if rnum > 0.8: # up
|
|
||||||
sf1 = random.uniform(1, 2)
|
|
||||||
elif rnum < 0.7: # down
|
|
||||||
sf1 = random.uniform(0.5 / sf, 1)
|
|
||||||
else:
|
|
||||||
sf1 = 1.0
|
|
||||||
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
|
||||||
# noise_level = random.randint(noise_level1, noise_level2)
|
|
||||||
# rnum = np.random.rand()
|
|
||||||
# if rnum > 0.6: # add color Gaussian noise
|
|
||||||
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
|
||||||
# elif rnum < 0.4: # add grayscale Gaussian noise
|
|
||||||
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
|
||||||
# else: # add noise
|
|
||||||
# L = noise_level2 / 255.
|
|
||||||
# D = np.diag(np.random.rand(3))
|
|
||||||
# U = orth(np.random.rand(3, 3))
|
|
||||||
# conv = np.dot(np.dot(np.transpose(U), D), U)
|
|
||||||
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
|
||||||
# img = np.clip(img, 0.0, 1.0)
|
|
||||||
# return img
|
|
||||||
|
|
||||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
|
||||||
noise_level = random.randint(noise_level1, noise_level2)
|
|
||||||
rnum = np.random.rand()
|
|
||||||
if rnum > 0.6: # add color Gaussian noise
|
|
||||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
|
||||||
elif rnum < 0.4: # add grayscale Gaussian noise
|
|
||||||
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
|
||||||
else: # add noise
|
|
||||||
L = noise_level2 / 255.
|
|
||||||
D = np.diag(np.random.rand(3))
|
|
||||||
U = orth(np.random.rand(3, 3))
|
|
||||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
|
||||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
|
||||||
noise_level = random.randint(noise_level1, noise_level2)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
rnum = random.random()
|
|
||||||
if rnum > 0.6:
|
|
||||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
|
||||||
elif rnum < 0.4:
|
|
||||||
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
|
||||||
else:
|
|
||||||
L = noise_level2 / 255.
|
|
||||||
D = np.diag(np.random.rand(3))
|
|
||||||
U = orth(np.random.rand(3, 3))
|
|
||||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
|
||||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_Poisson_noise(img):
|
|
||||||
img = np.clip((img * 255.0).round(), 0, 255) / 255.
|
|
||||||
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
|
|
||||||
if random.random() < 0.5:
|
|
||||||
img = np.random.poisson(img * vals).astype(np.float32) / vals
|
|
||||||
else:
|
|
||||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
|
||||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
|
|
||||||
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
|
||||||
img += noise_gray[:, :, np.newaxis]
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_JPEG_noise(img):
|
|
||||||
quality_factor = random.randint(80, 95)
|
|
||||||
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
|
||||||
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
|
||||||
img = cv2.imdecode(encimg, 1)
|
|
||||||
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
|
|
||||||
h, w = lq.shape[:2]
|
|
||||||
rnd_h = random.randint(0, h - lq_patchsize)
|
|
||||||
rnd_w = random.randint(0, w - lq_patchsize)
|
|
||||||
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
|
|
||||||
|
|
||||||
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
|
|
||||||
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
|
|
||||||
return lq, hq
|
|
||||||
|
|
||||||
|
|
||||||
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
|
||||||
"""
|
|
||||||
This is the degradation model of BSRGAN from the paper
|
|
||||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
|
||||||
----------
|
|
||||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
|
||||||
sf: scale factor
|
|
||||||
isp_model: camera ISP model
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
|
||||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
|
||||||
"""
|
|
||||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
|
||||||
sf_ori = sf
|
|
||||||
|
|
||||||
h1, w1 = img.shape[:2]
|
|
||||||
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
|
|
||||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
|
||||||
raise ValueError(f'img size ({h1}X{w1}) is too small!')
|
|
||||||
|
|
||||||
hq = img.copy()
|
|
||||||
|
|
||||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
|
||||||
if np.random.rand() < 0.5:
|
|
||||||
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]))
|
|
||||||
else:
|
|
||||||
img = util.imresize_np(img, 1 / 2, True)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
sf = 2
|
|
||||||
|
|
||||||
shuffle_order = random.sample(range(7), 7)
|
|
||||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
|
||||||
if idx1 > idx2: # keep downsample3 last
|
|
||||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
|
||||||
|
|
||||||
for i in shuffle_order:
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
img = add_blur(img, sf=sf)
|
|
||||||
|
|
||||||
elif i == 1:
|
|
||||||
img = add_blur(img, sf=sf)
|
|
||||||
|
|
||||||
elif i == 2:
|
|
||||||
a, b = img.shape[1], img.shape[0]
|
|
||||||
# downsample2
|
|
||||||
if random.random() < 0.75:
|
|
||||||
sf1 = random.uniform(1, 2 * sf)
|
|
||||||
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]))
|
|
||||||
else:
|
|
||||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
|
||||||
k_shifted = shift_pixel(k, sf)
|
|
||||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
|
||||||
img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
|
||||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 3:
|
|
||||||
# downsample3
|
|
||||||
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 4:
|
|
||||||
# add Gaussian noise
|
|
||||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
|
|
||||||
|
|
||||||
elif i == 5:
|
|
||||||
# add JPEG noise
|
|
||||||
if random.random() < jpeg_prob:
|
|
||||||
img = add_JPEG_noise(img)
|
|
||||||
|
|
||||||
elif i == 6:
|
|
||||||
# add processed camera sensor noise
|
|
||||||
if random.random() < isp_prob and isp_model is not None:
|
|
||||||
with torch.no_grad():
|
|
||||||
img, hq = isp_model.forward(img.copy(), hq)
|
|
||||||
|
|
||||||
# add final JPEG compression noise
|
|
||||||
img = add_JPEG_noise(img)
|
|
||||||
|
|
||||||
# random crop
|
|
||||||
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
|
|
||||||
|
|
||||||
return img, hq
|
|
||||||
|
|
||||||
|
|
||||||
# todo no isp_model?
|
|
||||||
def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
|
|
||||||
"""
|
|
||||||
This is the degradation model of BSRGAN from the paper
|
|
||||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
|
||||||
----------
|
|
||||||
sf: scale factor
|
|
||||||
isp_model: camera ISP model
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
|
||||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
|
||||||
"""
|
|
||||||
image = util.uint2single(image)
|
|
||||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
|
||||||
sf_ori = sf
|
|
||||||
|
|
||||||
h1, w1 = image.shape[:2]
|
|
||||||
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
|
||||||
h, w = image.shape[:2]
|
|
||||||
|
|
||||||
hq = image.copy()
|
|
||||||
|
|
||||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
|
||||||
if np.random.rand() < 0.5:
|
|
||||||
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]))
|
|
||||||
else:
|
|
||||||
image = util.imresize_np(image, 1 / 2, True)
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
|
||||||
sf = 2
|
|
||||||
|
|
||||||
shuffle_order = random.sample(range(7), 7)
|
|
||||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
|
||||||
if idx1 > idx2: # keep downsample3 last
|
|
||||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
|
||||||
|
|
||||||
for i in shuffle_order:
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
image = add_blur(image, sf=sf)
|
|
||||||
|
|
||||||
# elif i == 1:
|
|
||||||
# image = add_blur(image, sf=sf)
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif i == 2:
|
|
||||||
a, b = image.shape[1], image.shape[0]
|
|
||||||
# downsample2
|
|
||||||
if random.random() < 0.8:
|
|
||||||
sf1 = random.uniform(1, 2 * sf)
|
|
||||||
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]))
|
|
||||||
else:
|
|
||||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
|
||||||
k_shifted = shift_pixel(k, sf)
|
|
||||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
|
||||||
image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
|
||||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
|
||||||
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 3:
|
|
||||||
# downsample3
|
|
||||||
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 4:
|
|
||||||
# add Gaussian noise
|
|
||||||
image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
|
|
||||||
|
|
||||||
elif i == 5:
|
|
||||||
# add JPEG noise
|
|
||||||
if random.random() < jpeg_prob:
|
|
||||||
image = add_JPEG_noise(image)
|
|
||||||
#
|
|
||||||
# elif i == 6:
|
|
||||||
# # add processed camera sensor noise
|
|
||||||
# if random.random() < isp_prob and isp_model is not None:
|
|
||||||
# with torch.no_grad():
|
|
||||||
# img, hq = isp_model.forward(img.copy(), hq)
|
|
||||||
|
|
||||||
# add final JPEG compression noise
|
|
||||||
image = add_JPEG_noise(image)
|
|
||||||
image = util.single2uint(image)
|
|
||||||
if up:
|
|
||||||
image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
|
|
||||||
example = {"image": image}
|
|
||||||
return example
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
print("hey")
|
|
||||||
img = util.imread_uint('utils/test.png', 3)
|
|
||||||
img = img[:448, :448]
|
|
||||||
h = img.shape[0] // 4
|
|
||||||
print("resizing to", h)
|
|
||||||
sf = 4
|
|
||||||
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
|
|
||||||
for i in range(20):
|
|
||||||
print(i)
|
|
||||||
img_hq = img
|
|
||||||
img_lq = deg_fn(img)["image"]
|
|
||||||
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
|
|
||||||
print(img_lq)
|
|
||||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
|
|
||||||
print(img_lq.shape)
|
|
||||||
print("bicubic", img_lq_bicubic.shape)
|
|
||||||
print(img_hq.shape)
|
|
||||||
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
|
||||||
interpolation=0)
|
|
||||||
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
|
|
||||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
|
||||||
interpolation=0)
|
|
||||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
|
||||||
util.imsave(img_concat, str(i) + '.png')
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 431 KiB |
@ -1,916 +0,0 @@
|
|||||||
import os
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import cv2
|
|
||||||
from torchvision.utils import make_grid
|
|
||||||
from datetime import datetime
|
|
||||||
#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
|
||||||
|
|
||||||
|
|
||||||
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
# --------------------------------------------
|
|
||||||
# Kai Zhang (github: https://github.com/cszn)
|
|
||||||
# 03/Mar/2019
|
|
||||||
# --------------------------------------------
|
|
||||||
# https://github.com/twhui/SRGAN-pyTorch
|
|
||||||
# https://github.com/xinntao/BasicSR
|
|
||||||
# --------------------------------------------
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
|
|
||||||
|
|
||||||
|
|
||||||
def is_image_file(filename):
|
|
||||||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
|
||||||
|
|
||||||
|
|
||||||
def get_timestamp():
|
|
||||||
return datetime.now().strftime('%y%m%d-%H%M%S')
|
|
||||||
|
|
||||||
|
|
||||||
def imshow(x, title=None, cbar=False, figsize=None):
|
|
||||||
plt.figure(figsize=figsize)
|
|
||||||
plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
|
|
||||||
if title:
|
|
||||||
plt.title(title)
|
|
||||||
if cbar:
|
|
||||||
plt.colorbar()
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def surf(Z, cmap='rainbow', figsize=None):
|
|
||||||
plt.figure(figsize=figsize)
|
|
||||||
ax3 = plt.axes(projection='3d')
|
|
||||||
|
|
||||||
w, h = Z.shape[:2]
|
|
||||||
xx = np.arange(0,w,1)
|
|
||||||
yy = np.arange(0,h,1)
|
|
||||||
X, Y = np.meshgrid(xx, yy)
|
|
||||||
ax3.plot_surface(X,Y,Z,cmap=cmap)
|
|
||||||
#ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
# --------------------------------------------
|
|
||||||
# get image pathes
|
|
||||||
# --------------------------------------------
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
def get_image_paths(dataroot):
|
|
||||||
paths = None # return None if dataroot is None
|
|
||||||
if dataroot is not None:
|
|
||||||
paths = sorted(_get_paths_from_images(dataroot))
|
|
||||||
return paths
|
|
||||||
|
|
||||||
|
|
||||||
def _get_paths_from_images(path):
|
|
||||||
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
|
|
||||||
images = []
|
|
||||||
for dirpath, _, fnames in sorted(os.walk(path)):
|
|
||||||
for fname in sorted(fnames):
|
|
||||||
if is_image_file(fname):
|
|
||||||
img_path = os.path.join(dirpath, fname)
|
|
||||||
images.append(img_path)
|
|
||||||
assert images, '{:s} has no valid image file'.format(path)
|
|
||||||
return images
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
# --------------------------------------------
|
|
||||||
# split large images into small images
|
|
||||||
# --------------------------------------------
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
|
|
||||||
w, h = img.shape[:2]
|
|
||||||
patches = []
|
|
||||||
if w > p_max and h > p_max:
|
|
||||||
w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
|
|
||||||
h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
|
|
||||||
w1.append(w-p_size)
|
|
||||||
h1.append(h-p_size)
|
|
||||||
# print(w1)
|
|
||||||
# print(h1)
|
|
||||||
for i in w1:
|
|
||||||
for j in h1:
|
|
||||||
patches.append(img[i:i+p_size, j:j+p_size,:])
|
|
||||||
else:
|
|
||||||
patches.append(img)
|
|
||||||
|
|
||||||
return patches
|
|
||||||
|
|
||||||
|
|
||||||
def imssave(imgs, img_path):
|
|
||||||
"""
|
|
||||||
imgs: list, N images of size WxHxC
|
|
||||||
"""
|
|
||||||
img_name, ext = os.path.splitext(os.path.basename(img_path))
|
|
||||||
|
|
||||||
for i, img in enumerate(imgs):
|
|
||||||
if img.ndim == 3:
|
|
||||||
img = img[:, :, [2, 1, 0]]
|
|
||||||
new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
|
|
||||||
cv2.imwrite(new_path, img)
|
|
||||||
|
|
||||||
|
|
||||||
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
|
|
||||||
"""
|
|
||||||
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
|
|
||||||
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
|
|
||||||
will be splitted.
|
|
||||||
Args:
|
|
||||||
original_dataroot:
|
|
||||||
taget_dataroot:
|
|
||||||
p_size: size of small images
|
|
||||||
p_overlap: patch size in training is a good choice
|
|
||||||
p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
|
|
||||||
"""
|
|
||||||
paths = get_image_paths(original_dataroot)
|
|
||||||
for img_path in paths:
|
|
||||||
# img_name, ext = os.path.splitext(os.path.basename(img_path))
|
|
||||||
img = imread_uint(img_path, n_channels=n_channels)
|
|
||||||
patches = patches_from_image(img, p_size, p_overlap, p_max)
|
|
||||||
imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
|
|
||||||
#if original_dataroot == taget_dataroot:
|
|
||||||
#del img_path
|
|
||||||
|
|
||||||
'''
|
|
||||||
# --------------------------------------------
|
|
||||||
# makedir
|
|
||||||
# --------------------------------------------
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
def mkdir(path):
|
|
||||||
if not os.path.exists(path):
|
|
||||||
os.makedirs(path)
|
|
||||||
|
|
||||||
|
|
||||||
def mkdirs(paths):
|
|
||||||
if isinstance(paths, str):
|
|
||||||
mkdir(paths)
|
|
||||||
else:
|
|
||||||
for path in paths:
|
|
||||||
mkdir(path)
|
|
||||||
|
|
||||||
|
|
||||||
def mkdir_and_rename(path):
|
|
||||||
if os.path.exists(path):
|
|
||||||
new_name = path + '_archived_' + get_timestamp()
|
|
||||||
print('Path already exists. Rename it to [{:s}]'.format(new_name))
|
|
||||||
os.rename(path, new_name)
|
|
||||||
os.makedirs(path)
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
# --------------------------------------------
|
|
||||||
# read image from path
|
|
||||||
# opencv is fast, but read BGR numpy image
|
|
||||||
# --------------------------------------------
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# get uint8 image of size HxWxn_channles (RGB)
|
|
||||||
# --------------------------------------------
|
|
||||||
def imread_uint(path, n_channels=3):
|
|
||||||
# input: path
|
|
||||||
# output: HxWx3(RGB or GGG), or HxWx1 (G)
|
|
||||||
if n_channels == 1:
|
|
||||||
img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
|
|
||||||
img = np.expand_dims(img, axis=2) # HxWx1
|
|
||||||
elif n_channels == 3:
|
|
||||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
|
|
||||||
if img.ndim == 2:
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
|
|
||||||
else:
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# matlab's imwrite
|
|
||||||
# --------------------------------------------
|
|
||||||
def imsave(img, img_path):
|
|
||||||
img = np.squeeze(img)
|
|
||||||
if img.ndim == 3:
|
|
||||||
img = img[:, :, [2, 1, 0]]
|
|
||||||
cv2.imwrite(img_path, img)
|
|
||||||
|
|
||||||
def imwrite(img, img_path):
|
|
||||||
img = np.squeeze(img)
|
|
||||||
if img.ndim == 3:
|
|
||||||
img = img[:, :, [2, 1, 0]]
|
|
||||||
cv2.imwrite(img_path, img)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# get single image of size HxWxn_channles (BGR)
|
|
||||||
# --------------------------------------------
|
|
||||||
def read_img(path):
|
|
||||||
# read image by cv2
|
|
||||||
# return: Numpy float32, HWC, BGR, [0,1]
|
|
||||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
|
|
||||||
img = img.astype(np.float32) / 255.
|
|
||||||
if img.ndim == 2:
|
|
||||||
img = np.expand_dims(img, axis=2)
|
|
||||||
# some images have 4 channels
|
|
||||||
if img.shape[2] > 3:
|
|
||||||
img = img[:, :, :3]
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
# --------------------------------------------
|
|
||||||
# image format conversion
|
|
||||||
# --------------------------------------------
|
|
||||||
# numpy(single) <---> numpy(unit)
|
|
||||||
# numpy(single) <---> tensor
|
|
||||||
# numpy(unit) <---> tensor
|
|
||||||
# --------------------------------------------
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# numpy(single) [0, 1] <---> numpy(unit)
|
|
||||||
# --------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def uint2single(img):
|
|
||||||
|
|
||||||
return np.float32(img/255.)
|
|
||||||
|
|
||||||
|
|
||||||
def single2uint(img):
|
|
||||||
|
|
||||||
return np.uint8((img.clip(0, 1)*255.).round())
|
|
||||||
|
|
||||||
|
|
||||||
def uint162single(img):
|
|
||||||
|
|
||||||
return np.float32(img/65535.)
|
|
||||||
|
|
||||||
|
|
||||||
def single2uint16(img):
|
|
||||||
|
|
||||||
return np.uint16((img.clip(0, 1)*65535.).round())
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# numpy(unit) (HxWxC or HxW) <---> tensor
|
|
||||||
# --------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
# convert uint to 4-dimensional torch tensor
|
|
||||||
def uint2tensor4(img):
|
|
||||||
if img.ndim == 2:
|
|
||||||
img = np.expand_dims(img, axis=2)
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
# convert uint to 3-dimensional torch tensor
|
|
||||||
def uint2tensor3(img):
|
|
||||||
if img.ndim == 2:
|
|
||||||
img = np.expand_dims(img, axis=2)
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
|
|
||||||
|
|
||||||
|
|
||||||
# convert 2/3/4-dimensional torch tensor to uint
|
|
||||||
def tensor2uint(img):
|
|
||||||
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
|
|
||||||
if img.ndim == 3:
|
|
||||||
img = np.transpose(img, (1, 2, 0))
|
|
||||||
return np.uint8((img*255.0).round())
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# numpy(single) (HxWxC) <---> tensor
|
|
||||||
# --------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
# convert single (HxWxC) to 3-dimensional torch tensor
|
|
||||||
def single2tensor3(img):
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
|
|
||||||
|
|
||||||
|
|
||||||
# convert single (HxWxC) to 4-dimensional torch tensor
|
|
||||||
def single2tensor4(img):
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
# convert torch tensor to single
|
|
||||||
def tensor2single(img):
|
|
||||||
img = img.data.squeeze().float().cpu().numpy()
|
|
||||||
if img.ndim == 3:
|
|
||||||
img = np.transpose(img, (1, 2, 0))
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
# convert torch tensor to single
|
|
||||||
def tensor2single3(img):
|
|
||||||
img = img.data.squeeze().float().cpu().numpy()
|
|
||||||
if img.ndim == 3:
|
|
||||||
img = np.transpose(img, (1, 2, 0))
|
|
||||||
elif img.ndim == 2:
|
|
||||||
img = np.expand_dims(img, axis=2)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def single2tensor5(img):
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
def single32tensor5(img):
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
def single42tensor4(img):
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
|
|
||||||
|
|
||||||
|
|
||||||
# from skimage.io import imread, imsave
|
|
||||||
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
|
|
||||||
'''
|
|
||||||
Converts a torch Tensor into an image Numpy array of BGR channel order
|
|
||||||
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
|
|
||||||
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
|
|
||||||
'''
|
|
||||||
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
|
|
||||||
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
|
|
||||||
n_dim = tensor.dim()
|
|
||||||
if n_dim == 4:
|
|
||||||
n_img = len(tensor)
|
|
||||||
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
|
|
||||||
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
|
||||||
elif n_dim == 3:
|
|
||||||
img_np = tensor.numpy()
|
|
||||||
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
|
||||||
elif n_dim == 2:
|
|
||||||
img_np = tensor.numpy()
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
|
|
||||||
if out_type == np.uint8:
|
|
||||||
img_np = (img_np * 255.0).round()
|
|
||||||
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
|
|
||||||
return img_np.astype(out_type)
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
# --------------------------------------------
|
|
||||||
# Augmentation, flipe and/or rotate
|
|
||||||
# --------------------------------------------
|
|
||||||
# The following two are enough.
|
|
||||||
# (1) augmet_img: numpy image of WxHxC or WxH
|
|
||||||
# (2) augment_img_tensor4: tensor image 1xCxWxH
|
|
||||||
# --------------------------------------------
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
def augment_img(img, mode=0):
|
|
||||||
'''Kai Zhang (github: https://github.com/cszn)
|
|
||||||
'''
|
|
||||||
if mode == 0:
|
|
||||||
return img
|
|
||||||
elif mode == 1:
|
|
||||||
return np.flipud(np.rot90(img))
|
|
||||||
elif mode == 2:
|
|
||||||
return np.flipud(img)
|
|
||||||
elif mode == 3:
|
|
||||||
return np.rot90(img, k=3)
|
|
||||||
elif mode == 4:
|
|
||||||
return np.flipud(np.rot90(img, k=2))
|
|
||||||
elif mode == 5:
|
|
||||||
return np.rot90(img)
|
|
||||||
elif mode == 6:
|
|
||||||
return np.rot90(img, k=2)
|
|
||||||
elif mode == 7:
|
|
||||||
return np.flipud(np.rot90(img, k=3))
|
|
||||||
|
|
||||||
|
|
||||||
def augment_img_tensor4(img, mode=0):
|
|
||||||
'''Kai Zhang (github: https://github.com/cszn)
|
|
||||||
'''
|
|
||||||
if mode == 0:
|
|
||||||
return img
|
|
||||||
elif mode == 1:
|
|
||||||
return img.rot90(1, [2, 3]).flip([2])
|
|
||||||
elif mode == 2:
|
|
||||||
return img.flip([2])
|
|
||||||
elif mode == 3:
|
|
||||||
return img.rot90(3, [2, 3])
|
|
||||||
elif mode == 4:
|
|
||||||
return img.rot90(2, [2, 3]).flip([2])
|
|
||||||
elif mode == 5:
|
|
||||||
return img.rot90(1, [2, 3])
|
|
||||||
elif mode == 6:
|
|
||||||
return img.rot90(2, [2, 3])
|
|
||||||
elif mode == 7:
|
|
||||||
return img.rot90(3, [2, 3]).flip([2])
|
|
||||||
|
|
||||||
|
|
||||||
def augment_img_tensor(img, mode=0):
|
|
||||||
'''Kai Zhang (github: https://github.com/cszn)
|
|
||||||
'''
|
|
||||||
img_size = img.size()
|
|
||||||
img_np = img.data.cpu().numpy()
|
|
||||||
if len(img_size) == 3:
|
|
||||||
img_np = np.transpose(img_np, (1, 2, 0))
|
|
||||||
elif len(img_size) == 4:
|
|
||||||
img_np = np.transpose(img_np, (2, 3, 1, 0))
|
|
||||||
img_np = augment_img(img_np, mode=mode)
|
|
||||||
img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
|
|
||||||
if len(img_size) == 3:
|
|
||||||
img_tensor = img_tensor.permute(2, 0, 1)
|
|
||||||
elif len(img_size) == 4:
|
|
||||||
img_tensor = img_tensor.permute(3, 2, 0, 1)
|
|
||||||
|
|
||||||
return img_tensor.type_as(img)
|
|
||||||
|
|
||||||
|
|
||||||
def augment_img_np3(img, mode=0):
|
|
||||||
if mode == 0:
|
|
||||||
return img
|
|
||||||
elif mode == 1:
|
|
||||||
return img.transpose(1, 0, 2)
|
|
||||||
elif mode == 2:
|
|
||||||
return img[::-1, :, :]
|
|
||||||
elif mode == 3:
|
|
||||||
img = img[::-1, :, :]
|
|
||||||
img = img.transpose(1, 0, 2)
|
|
||||||
return img
|
|
||||||
elif mode == 4:
|
|
||||||
return img[:, ::-1, :]
|
|
||||||
elif mode == 5:
|
|
||||||
img = img[:, ::-1, :]
|
|
||||||
img = img.transpose(1, 0, 2)
|
|
||||||
return img
|
|
||||||
elif mode == 6:
|
|
||||||
img = img[:, ::-1, :]
|
|
||||||
img = img[::-1, :, :]
|
|
||||||
return img
|
|
||||||
elif mode == 7:
|
|
||||||
img = img[:, ::-1, :]
|
|
||||||
img = img[::-1, :, :]
|
|
||||||
img = img.transpose(1, 0, 2)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def augment_imgs(img_list, hflip=True, rot=True):
|
|
||||||
# horizontal flip OR rotate
|
|
||||||
hflip = hflip and random.random() < 0.5
|
|
||||||
vflip = rot and random.random() < 0.5
|
|
||||||
rot90 = rot and random.random() < 0.5
|
|
||||||
|
|
||||||
def _augment(img):
|
|
||||||
if hflip:
|
|
||||||
img = img[:, ::-1, :]
|
|
||||||
if vflip:
|
|
||||||
img = img[::-1, :, :]
|
|
||||||
if rot90:
|
|
||||||
img = img.transpose(1, 0, 2)
|
|
||||||
return img
|
|
||||||
|
|
||||||
return [_augment(img) for img in img_list]
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
# --------------------------------------------
|
|
||||||
# modcrop and shave
|
|
||||||
# --------------------------------------------
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
def modcrop(img_in, scale):
|
|
||||||
# img_in: Numpy, HWC or HW
|
|
||||||
img = np.copy(img_in)
|
|
||||||
if img.ndim == 2:
|
|
||||||
H, W = img.shape
|
|
||||||
H_r, W_r = H % scale, W % scale
|
|
||||||
img = img[:H - H_r, :W - W_r]
|
|
||||||
elif img.ndim == 3:
|
|
||||||
H, W, C = img.shape
|
|
||||||
H_r, W_r = H % scale, W % scale
|
|
||||||
img = img[:H - H_r, :W - W_r, :]
|
|
||||||
else:
|
|
||||||
raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def shave(img_in, border=0):
|
|
||||||
# img_in: Numpy, HWC or HW
|
|
||||||
img = np.copy(img_in)
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
img = img[border:h-border, border:w-border]
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
# --------------------------------------------
|
|
||||||
# image processing process on numpy image
|
|
||||||
# channel_convert(in_c, tar_type, img_list):
|
|
||||||
# rgb2ycbcr(img, only_y=True):
|
|
||||||
# bgr2ycbcr(img, only_y=True):
|
|
||||||
# ycbcr2rgb(img):
|
|
||||||
# --------------------------------------------
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
def rgb2ycbcr(img, only_y=True):
|
|
||||||
'''same as matlab rgb2ycbcr
|
|
||||||
only_y: only return Y channel
|
|
||||||
Input:
|
|
||||||
uint8, [0, 255]
|
|
||||||
float, [0, 1]
|
|
||||||
'''
|
|
||||||
in_img_type = img.dtype
|
|
||||||
img.astype(np.float32)
|
|
||||||
if in_img_type != np.uint8:
|
|
||||||
img *= 255.
|
|
||||||
# convert
|
|
||||||
if only_y:
|
|
||||||
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
|
||||||
else:
|
|
||||||
rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
|
|
||||||
[24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
|
|
||||||
if in_img_type == np.uint8:
|
|
||||||
rlt = rlt.round()
|
|
||||||
else:
|
|
||||||
rlt /= 255.
|
|
||||||
return rlt.astype(in_img_type)
|
|
||||||
|
|
||||||
|
|
||||||
def ycbcr2rgb(img):
|
|
||||||
'''same as matlab ycbcr2rgb
|
|
||||||
Input:
|
|
||||||
uint8, [0, 255]
|
|
||||||
float, [0, 1]
|
|
||||||
'''
|
|
||||||
in_img_type = img.dtype
|
|
||||||
img.astype(np.float32)
|
|
||||||
if in_img_type != np.uint8:
|
|
||||||
img *= 255.
|
|
||||||
# convert
|
|
||||||
rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
|
|
||||||
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
|
|
||||||
if in_img_type == np.uint8:
|
|
||||||
rlt = rlt.round()
|
|
||||||
else:
|
|
||||||
rlt /= 255.
|
|
||||||
return rlt.astype(in_img_type)
|
|
||||||
|
|
||||||
|
|
||||||
def bgr2ycbcr(img, only_y=True):
|
|
||||||
'''bgr version of rgb2ycbcr
|
|
||||||
only_y: only return Y channel
|
|
||||||
Input:
|
|
||||||
uint8, [0, 255]
|
|
||||||
float, [0, 1]
|
|
||||||
'''
|
|
||||||
in_img_type = img.dtype
|
|
||||||
img.astype(np.float32)
|
|
||||||
if in_img_type != np.uint8:
|
|
||||||
img *= 255.
|
|
||||||
# convert
|
|
||||||
if only_y:
|
|
||||||
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
|
||||||
else:
|
|
||||||
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
|
|
||||||
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
|
|
||||||
if in_img_type == np.uint8:
|
|
||||||
rlt = rlt.round()
|
|
||||||
else:
|
|
||||||
rlt /= 255.
|
|
||||||
return rlt.astype(in_img_type)
|
|
||||||
|
|
||||||
|
|
||||||
def channel_convert(in_c, tar_type, img_list):
|
|
||||||
# conversion among BGR, gray and y
|
|
||||||
if in_c == 3 and tar_type == 'gray': # BGR to gray
|
|
||||||
gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
|
|
||||||
return [np.expand_dims(img, axis=2) for img in gray_list]
|
|
||||||
elif in_c == 3 and tar_type == 'y': # BGR to y
|
|
||||||
y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
|
|
||||||
return [np.expand_dims(img, axis=2) for img in y_list]
|
|
||||||
elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
|
|
||||||
return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
|
|
||||||
else:
|
|
||||||
return img_list
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
# --------------------------------------------
|
|
||||||
# metric, PSNR and SSIM
|
|
||||||
# --------------------------------------------
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# PSNR
|
|
||||||
# --------------------------------------------
|
|
||||||
def calculate_psnr(img1, img2, border=0):
|
|
||||||
# img1 and img2 have range [0, 255]
|
|
||||||
#img1 = img1.squeeze()
|
|
||||||
#img2 = img2.squeeze()
|
|
||||||
if not img1.shape == img2.shape:
|
|
||||||
raise ValueError('Input images must have the same dimensions.')
|
|
||||||
h, w = img1.shape[:2]
|
|
||||||
img1 = img1[border:h-border, border:w-border]
|
|
||||||
img2 = img2[border:h-border, border:w-border]
|
|
||||||
|
|
||||||
img1 = img1.astype(np.float64)
|
|
||||||
img2 = img2.astype(np.float64)
|
|
||||||
mse = np.mean((img1 - img2)**2)
|
|
||||||
if mse == 0:
|
|
||||||
return float('inf')
|
|
||||||
return 20 * math.log10(255.0 / math.sqrt(mse))
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# SSIM
|
|
||||||
# --------------------------------------------
|
|
||||||
def calculate_ssim(img1, img2, border=0):
|
|
||||||
'''calculate SSIM
|
|
||||||
the same outputs as MATLAB's
|
|
||||||
img1, img2: [0, 255]
|
|
||||||
'''
|
|
||||||
#img1 = img1.squeeze()
|
|
||||||
#img2 = img2.squeeze()
|
|
||||||
if not img1.shape == img2.shape:
|
|
||||||
raise ValueError('Input images must have the same dimensions.')
|
|
||||||
h, w = img1.shape[:2]
|
|
||||||
img1 = img1[border:h-border, border:w-border]
|
|
||||||
img2 = img2[border:h-border, border:w-border]
|
|
||||||
|
|
||||||
if img1.ndim == 2:
|
|
||||||
return ssim(img1, img2)
|
|
||||||
elif img1.ndim == 3:
|
|
||||||
if img1.shape[2] == 3:
|
|
||||||
ssims = []
|
|
||||||
for i in range(3):
|
|
||||||
ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
|
|
||||||
return np.array(ssims).mean()
|
|
||||||
elif img1.shape[2] == 1:
|
|
||||||
return ssim(np.squeeze(img1), np.squeeze(img2))
|
|
||||||
else:
|
|
||||||
raise ValueError('Wrong input image dimensions.')
|
|
||||||
|
|
||||||
|
|
||||||
def ssim(img1, img2):
|
|
||||||
C1 = (0.01 * 255)**2
|
|
||||||
C2 = (0.03 * 255)**2
|
|
||||||
|
|
||||||
img1 = img1.astype(np.float64)
|
|
||||||
img2 = img2.astype(np.float64)
|
|
||||||
kernel = cv2.getGaussianKernel(11, 1.5)
|
|
||||||
window = np.outer(kernel, kernel.transpose())
|
|
||||||
|
|
||||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
|
||||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
|
||||||
mu1_sq = mu1**2
|
|
||||||
mu2_sq = mu2**2
|
|
||||||
mu1_mu2 = mu1 * mu2
|
|
||||||
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
|
||||||
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
|
||||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
|
||||||
|
|
||||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
|
||||||
(sigma1_sq + sigma2_sq + C2))
|
|
||||||
return ssim_map.mean()
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
# --------------------------------------------
|
|
||||||
# matlab's bicubic imresize (numpy and torch) [0, 1]
|
|
||||||
# --------------------------------------------
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
# matlab 'imresize' function, now only support 'bicubic'
|
|
||||||
def cubic(x):
|
|
||||||
absx = torch.abs(x)
|
|
||||||
absx2 = absx**2
|
|
||||||
absx3 = absx**3
|
|
||||||
return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
|
|
||||||
(-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
|
|
||||||
if (scale < 1) and (antialiasing):
|
|
||||||
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
|
|
||||||
kernel_width = kernel_width / scale
|
|
||||||
|
|
||||||
# Output-space coordinates
|
|
||||||
x = torch.linspace(1, out_length, out_length)
|
|
||||||
|
|
||||||
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
|
||||||
# in output space maps to 0.5 in input space, and 0.5+scale in output
|
|
||||||
# space maps to 1.5 in input space.
|
|
||||||
u = x / scale + 0.5 * (1 - 1 / scale)
|
|
||||||
|
|
||||||
# What is the left-most pixel that can be involved in the computation?
|
|
||||||
left = torch.floor(u - kernel_width / 2)
|
|
||||||
|
|
||||||
# What is the maximum number of pixels that can be involved in the
|
|
||||||
# computation? Note: it's OK to use an extra pixel here; if the
|
|
||||||
# corresponding weights are all zero, it will be eliminated at the end
|
|
||||||
# of this function.
|
|
||||||
P = math.ceil(kernel_width) + 2
|
|
||||||
|
|
||||||
# The indices of the input pixels involved in computing the k-th output
|
|
||||||
# pixel are in row k of the indices matrix.
|
|
||||||
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
|
|
||||||
1, P).expand(out_length, P)
|
|
||||||
|
|
||||||
# The weights used to compute the k-th output pixel are in row k of the
|
|
||||||
# weights matrix.
|
|
||||||
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
|
|
||||||
# apply cubic kernel
|
|
||||||
if (scale < 1) and (antialiasing):
|
|
||||||
weights = scale * cubic(distance_to_center * scale)
|
|
||||||
else:
|
|
||||||
weights = cubic(distance_to_center)
|
|
||||||
# Normalize the weights matrix so that each row sums to 1.
|
|
||||||
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
|
||||||
weights = weights / weights_sum.expand(out_length, P)
|
|
||||||
|
|
||||||
# If a column in weights is all zero, get rid of it. only consider the first and last column.
|
|
||||||
weights_zero_tmp = torch.sum((weights == 0), 0)
|
|
||||||
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
|
||||||
indices = indices.narrow(1, 1, P - 2)
|
|
||||||
weights = weights.narrow(1, 1, P - 2)
|
|
||||||
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
|
||||||
indices = indices.narrow(1, 0, P - 2)
|
|
||||||
weights = weights.narrow(1, 0, P - 2)
|
|
||||||
weights = weights.contiguous()
|
|
||||||
indices = indices.contiguous()
|
|
||||||
sym_len_s = -indices.min() + 1
|
|
||||||
sym_len_e = indices.max() - in_length
|
|
||||||
indices = indices + sym_len_s - 1
|
|
||||||
return weights, indices, int(sym_len_s), int(sym_len_e)
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# imresize for tensor image [0, 1]
|
|
||||||
# --------------------------------------------
|
|
||||||
def imresize(img, scale, antialiasing=True):
|
|
||||||
# Now the scale should be the same for H and W
|
|
||||||
# input: img: pytorch tensor, CHW or HW [0,1]
|
|
||||||
# output: CHW or HW [0,1] w/o round
|
|
||||||
need_squeeze = True if img.dim() == 2 else False
|
|
||||||
if need_squeeze:
|
|
||||||
img.unsqueeze_(0)
|
|
||||||
in_C, in_H, in_W = img.size()
|
|
||||||
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
|
||||||
kernel_width = 4
|
|
||||||
kernel = 'cubic'
|
|
||||||
|
|
||||||
# Return the desired dimension order for performing the resize. The
|
|
||||||
# strategy is to perform the resize first along the dimension with the
|
|
||||||
# smallest scale factor.
|
|
||||||
# Now we do not support this.
|
|
||||||
|
|
||||||
# get weights and indices
|
|
||||||
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
|
||||||
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
|
||||||
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
|
||||||
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
|
||||||
# process H dimension
|
|
||||||
# symmetric copying
|
|
||||||
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
|
|
||||||
img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
|
|
||||||
|
|
||||||
sym_patch = img[:, :sym_len_Hs, :]
|
|
||||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
|
||||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
|
||||||
img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
|
|
||||||
|
|
||||||
sym_patch = img[:, -sym_len_He:, :]
|
|
||||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
|
||||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
|
||||||
img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
|
||||||
|
|
||||||
out_1 = torch.FloatTensor(in_C, out_H, in_W)
|
|
||||||
kernel_width = weights_H.size(1)
|
|
||||||
for i in range(out_H):
|
|
||||||
idx = int(indices_H[i][0])
|
|
||||||
for j in range(out_C):
|
|
||||||
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
|
||||||
|
|
||||||
# process W dimension
|
|
||||||
# symmetric copying
|
|
||||||
out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
|
|
||||||
out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
|
|
||||||
|
|
||||||
sym_patch = out_1[:, :, :sym_len_Ws]
|
|
||||||
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
|
||||||
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
|
||||||
out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
|
|
||||||
|
|
||||||
sym_patch = out_1[:, :, -sym_len_We:]
|
|
||||||
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
|
||||||
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
|
||||||
out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
|
||||||
|
|
||||||
out_2 = torch.FloatTensor(in_C, out_H, out_W)
|
|
||||||
kernel_width = weights_W.size(1)
|
|
||||||
for i in range(out_W):
|
|
||||||
idx = int(indices_W[i][0])
|
|
||||||
for j in range(out_C):
|
|
||||||
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
|
|
||||||
if need_squeeze:
|
|
||||||
out_2.squeeze_()
|
|
||||||
return out_2
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# imresize for numpy image [0, 1]
|
|
||||||
# --------------------------------------------
|
|
||||||
def imresize_np(img, scale, antialiasing=True):
|
|
||||||
# Now the scale should be the same for H and W
|
|
||||||
# input: img: Numpy, HWC or HW [0,1]
|
|
||||||
# output: HWC or HW [0,1] w/o round
|
|
||||||
img = torch.from_numpy(img)
|
|
||||||
need_squeeze = True if img.dim() == 2 else False
|
|
||||||
if need_squeeze:
|
|
||||||
img.unsqueeze_(2)
|
|
||||||
|
|
||||||
in_H, in_W, in_C = img.size()
|
|
||||||
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
|
||||||
kernel_width = 4
|
|
||||||
kernel = 'cubic'
|
|
||||||
|
|
||||||
# Return the desired dimension order for performing the resize. The
|
|
||||||
# strategy is to perform the resize first along the dimension with the
|
|
||||||
# smallest scale factor.
|
|
||||||
# Now we do not support this.
|
|
||||||
|
|
||||||
# get weights and indices
|
|
||||||
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
|
||||||
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
|
||||||
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
|
||||||
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
|
||||||
# process H dimension
|
|
||||||
# symmetric copying
|
|
||||||
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
|
|
||||||
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
|
|
||||||
|
|
||||||
sym_patch = img[:sym_len_Hs, :, :]
|
|
||||||
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
|
||||||
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
|
||||||
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
|
|
||||||
|
|
||||||
sym_patch = img[-sym_len_He:, :, :]
|
|
||||||
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
|
||||||
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
|
||||||
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
|
||||||
|
|
||||||
out_1 = torch.FloatTensor(out_H, in_W, in_C)
|
|
||||||
kernel_width = weights_H.size(1)
|
|
||||||
for i in range(out_H):
|
|
||||||
idx = int(indices_H[i][0])
|
|
||||||
for j in range(out_C):
|
|
||||||
out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
|
|
||||||
|
|
||||||
# process W dimension
|
|
||||||
# symmetric copying
|
|
||||||
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
|
|
||||||
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
|
|
||||||
|
|
||||||
sym_patch = out_1[:, :sym_len_Ws, :]
|
|
||||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
|
||||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
|
||||||
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
|
|
||||||
|
|
||||||
sym_patch = out_1[:, -sym_len_We:, :]
|
|
||||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
|
||||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
|
||||||
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
|
||||||
|
|
||||||
out_2 = torch.FloatTensor(out_H, out_W, in_C)
|
|
||||||
kernel_width = weights_W.size(1)
|
|
||||||
for i in range(out_W):
|
|
||||||
idx = int(indices_W[i][0])
|
|
||||||
for j in range(out_C):
|
|
||||||
out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
|
|
||||||
if need_squeeze:
|
|
||||||
out_2.squeeze_()
|
|
||||||
|
|
||||||
return out_2.numpy()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
print('---')
|
|
||||||
# img = imread_uint('test.bmp', 3)
|
|
||||||
# img = uint2single(img)
|
|
||||||
# img_bicubic = imresize_np(img, 1/4)
|
|
||||||
@ -1,170 +0,0 @@
|
|||||||
# based on https://github.com/isl-org/MiDaS
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torchvision.transforms import Compose
|
|
||||||
|
|
||||||
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
|
|
||||||
from ldm.modules.midas.midas.midas_net import MidasNet
|
|
||||||
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
|
|
||||||
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
|
|
||||||
|
|
||||||
|
|
||||||
ISL_PATHS = {
|
|
||||||
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
|
|
||||||
"dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
|
|
||||||
"midas_v21": "",
|
|
||||||
"midas_v21_small": "",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def disabled_train(self, mode=True):
|
|
||||||
"""Overwrite model.train with this function to make sure train/eval mode
|
|
||||||
does not change anymore."""
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
def load_midas_transform(model_type):
|
|
||||||
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
|
||||||
# load transform only
|
|
||||||
if model_type == "dpt_large": # DPT-Large
|
|
||||||
net_w, net_h = 384, 384
|
|
||||||
resize_mode = "minimal"
|
|
||||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
|
||||||
|
|
||||||
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
|
||||||
net_w, net_h = 384, 384
|
|
||||||
resize_mode = "minimal"
|
|
||||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
|
||||||
|
|
||||||
elif model_type == "midas_v21":
|
|
||||||
net_w, net_h = 384, 384
|
|
||||||
resize_mode = "upper_bound"
|
|
||||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
|
||||||
|
|
||||||
elif model_type == "midas_v21_small":
|
|
||||||
net_w, net_h = 256, 256
|
|
||||||
resize_mode = "upper_bound"
|
|
||||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
|
||||||
|
|
||||||
else:
|
|
||||||
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
|
||||||
|
|
||||||
transform = Compose(
|
|
||||||
[
|
|
||||||
Resize(
|
|
||||||
net_w,
|
|
||||||
net_h,
|
|
||||||
resize_target=None,
|
|
||||||
keep_aspect_ratio=True,
|
|
||||||
ensure_multiple_of=32,
|
|
||||||
resize_method=resize_mode,
|
|
||||||
image_interpolation_method=cv2.INTER_CUBIC,
|
|
||||||
),
|
|
||||||
normalization,
|
|
||||||
PrepareForNet(),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return transform
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_type):
|
|
||||||
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
|
||||||
# load network
|
|
||||||
model_path = ISL_PATHS[model_type]
|
|
||||||
if model_type == "dpt_large": # DPT-Large
|
|
||||||
model = DPTDepthModel(
|
|
||||||
path=model_path,
|
|
||||||
backbone="vitl16_384",
|
|
||||||
non_negative=True,
|
|
||||||
)
|
|
||||||
net_w, net_h = 384, 384
|
|
||||||
resize_mode = "minimal"
|
|
||||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
|
||||||
|
|
||||||
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
|
||||||
model = DPTDepthModel(
|
|
||||||
path=model_path,
|
|
||||||
backbone="vitb_rn50_384",
|
|
||||||
non_negative=True,
|
|
||||||
)
|
|
||||||
net_w, net_h = 384, 384
|
|
||||||
resize_mode = "minimal"
|
|
||||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
|
||||||
|
|
||||||
elif model_type == "midas_v21":
|
|
||||||
model = MidasNet(model_path, non_negative=True)
|
|
||||||
net_w, net_h = 384, 384
|
|
||||||
resize_mode = "upper_bound"
|
|
||||||
normalization = NormalizeImage(
|
|
||||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
|
||||||
)
|
|
||||||
|
|
||||||
elif model_type == "midas_v21_small":
|
|
||||||
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
|
|
||||||
non_negative=True, blocks={'expand': True})
|
|
||||||
net_w, net_h = 256, 256
|
|
||||||
resize_mode = "upper_bound"
|
|
||||||
normalization = NormalizeImage(
|
|
||||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
print(f"model_type '{model_type}' not implemented, use: --model_type large")
|
|
||||||
assert False
|
|
||||||
|
|
||||||
transform = Compose(
|
|
||||||
[
|
|
||||||
Resize(
|
|
||||||
net_w,
|
|
||||||
net_h,
|
|
||||||
resize_target=None,
|
|
||||||
keep_aspect_ratio=True,
|
|
||||||
ensure_multiple_of=32,
|
|
||||||
resize_method=resize_mode,
|
|
||||||
image_interpolation_method=cv2.INTER_CUBIC,
|
|
||||||
),
|
|
||||||
normalization,
|
|
||||||
PrepareForNet(),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return model.eval(), transform
|
|
||||||
|
|
||||||
|
|
||||||
class MiDaSInference(nn.Module):
|
|
||||||
MODEL_TYPES_TORCH_HUB = [
|
|
||||||
"DPT_Large",
|
|
||||||
"DPT_Hybrid",
|
|
||||||
"MiDaS_small"
|
|
||||||
]
|
|
||||||
MODEL_TYPES_ISL = [
|
|
||||||
"dpt_large",
|
|
||||||
"dpt_hybrid",
|
|
||||||
"midas_v21",
|
|
||||||
"midas_v21_small",
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self, model_type):
|
|
||||||
super().__init__()
|
|
||||||
assert (model_type in self.MODEL_TYPES_ISL)
|
|
||||||
model, _ = load_model(model_type)
|
|
||||||
self.model = model
|
|
||||||
self.model.train = disabled_train
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
|
|
||||||
# NOTE: we expect that the correct transform has been called during dataloading.
|
|
||||||
with torch.no_grad():
|
|
||||||
prediction = self.model(x)
|
|
||||||
prediction = torch.nn.functional.interpolate(
|
|
||||||
prediction.unsqueeze(1),
|
|
||||||
size=x.shape[2:],
|
|
||||||
mode="bicubic",
|
|
||||||
align_corners=False,
|
|
||||||
)
|
|
||||||
assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
|
|
||||||
return prediction
|
|
||||||
|
|
||||||
@ -1,16 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(torch.nn.Module):
|
|
||||||
def load(self, path):
|
|
||||||
"""Load model from file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (str): file path
|
|
||||||
"""
|
|
||||||
parameters = torch.load(path, map_location=torch.device('cpu'))
|
|
||||||
|
|
||||||
if "optimizer" in parameters:
|
|
||||||
parameters = parameters["model"]
|
|
||||||
|
|
||||||
self.load_state_dict(parameters)
|
|
||||||
@ -1,342 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from .vit import (
|
|
||||||
_make_pretrained_vitb_rn50_384,
|
|
||||||
_make_pretrained_vitl16_384,
|
|
||||||
_make_pretrained_vitb16_384,
|
|
||||||
forward_vit,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
|
|
||||||
if backbone == "vitl16_384":
|
|
||||||
pretrained = _make_pretrained_vitl16_384(
|
|
||||||
use_pretrained, hooks=hooks, use_readout=use_readout
|
|
||||||
)
|
|
||||||
scratch = _make_scratch(
|
|
||||||
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
|
||||||
) # ViT-L/16 - 85.0% Top1 (backbone)
|
|
||||||
elif backbone == "vitb_rn50_384":
|
|
||||||
pretrained = _make_pretrained_vitb_rn50_384(
|
|
||||||
use_pretrained,
|
|
||||||
hooks=hooks,
|
|
||||||
use_vit_only=use_vit_only,
|
|
||||||
use_readout=use_readout,
|
|
||||||
)
|
|
||||||
scratch = _make_scratch(
|
|
||||||
[256, 512, 768, 768], features, groups=groups, expand=expand
|
|
||||||
) # ViT-H/16 - 85.0% Top1 (backbone)
|
|
||||||
elif backbone == "vitb16_384":
|
|
||||||
pretrained = _make_pretrained_vitb16_384(
|
|
||||||
use_pretrained, hooks=hooks, use_readout=use_readout
|
|
||||||
)
|
|
||||||
scratch = _make_scratch(
|
|
||||||
[96, 192, 384, 768], features, groups=groups, expand=expand
|
|
||||||
) # ViT-B/16 - 84.6% Top1 (backbone)
|
|
||||||
elif backbone == "resnext101_wsl":
|
|
||||||
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
|
||||||
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
|
|
||||||
elif backbone == "efficientnet_lite3":
|
|
||||||
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
|
|
||||||
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
|
|
||||||
else:
|
|
||||||
print(f"Backbone '{backbone}' not implemented")
|
|
||||||
assert False
|
|
||||||
|
|
||||||
return pretrained, scratch
|
|
||||||
|
|
||||||
|
|
||||||
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
|
||||||
scratch = nn.Module()
|
|
||||||
|
|
||||||
out_shape1 = out_shape
|
|
||||||
out_shape2 = out_shape
|
|
||||||
out_shape3 = out_shape
|
|
||||||
out_shape4 = out_shape
|
|
||||||
if expand==True:
|
|
||||||
out_shape1 = out_shape
|
|
||||||
out_shape2 = out_shape*2
|
|
||||||
out_shape3 = out_shape*4
|
|
||||||
out_shape4 = out_shape*8
|
|
||||||
|
|
||||||
scratch.layer1_rn = nn.Conv2d(
|
|
||||||
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
|
||||||
)
|
|
||||||
scratch.layer2_rn = nn.Conv2d(
|
|
||||||
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
|
||||||
)
|
|
||||||
scratch.layer3_rn = nn.Conv2d(
|
|
||||||
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
|
||||||
)
|
|
||||||
scratch.layer4_rn = nn.Conv2d(
|
|
||||||
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
|
||||||
)
|
|
||||||
|
|
||||||
return scratch
|
|
||||||
|
|
||||||
|
|
||||||
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
|
||||||
efficientnet = torch.hub.load(
|
|
||||||
"rwightman/gen-efficientnet-pytorch",
|
|
||||||
"tf_efficientnet_lite3",
|
|
||||||
pretrained=use_pretrained,
|
|
||||||
exportable=exportable
|
|
||||||
)
|
|
||||||
return _make_efficientnet_backbone(efficientnet)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_efficientnet_backbone(effnet):
|
|
||||||
pretrained = nn.Module()
|
|
||||||
|
|
||||||
pretrained.layer1 = nn.Sequential(
|
|
||||||
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
|
||||||
)
|
|
||||||
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
|
||||||
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
|
||||||
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
|
||||||
|
|
||||||
return pretrained
|
|
||||||
|
|
||||||
|
|
||||||
def _make_resnet_backbone(resnet):
|
|
||||||
pretrained = nn.Module()
|
|
||||||
pretrained.layer1 = nn.Sequential(
|
|
||||||
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained.layer2 = resnet.layer2
|
|
||||||
pretrained.layer3 = resnet.layer3
|
|
||||||
pretrained.layer4 = resnet.layer4
|
|
||||||
|
|
||||||
return pretrained
|
|
||||||
|
|
||||||
|
|
||||||
def _make_pretrained_resnext101_wsl(use_pretrained):
|
|
||||||
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
|
||||||
return _make_resnet_backbone(resnet)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Interpolate(nn.Module):
|
|
||||||
"""Interpolation module.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, scale_factor, mode, align_corners=False):
|
|
||||||
"""Init.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scale_factor (float): scaling
|
|
||||||
mode (str): interpolation mode
|
|
||||||
"""
|
|
||||||
super(Interpolate, self).__init__()
|
|
||||||
|
|
||||||
self.interp = nn.functional.interpolate
|
|
||||||
self.scale_factor = scale_factor
|
|
||||||
self.mode = mode
|
|
||||||
self.align_corners = align_corners
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""Forward pass.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (tensor): input
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tensor: interpolated data
|
|
||||||
"""
|
|
||||||
|
|
||||||
x = self.interp(
|
|
||||||
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
|
||||||
)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualConvUnit(nn.Module):
|
|
||||||
"""Residual convolution module.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, features):
|
|
||||||
"""Init.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features (int): number of features
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(
|
|
||||||
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.conv2 = nn.Conv2d(
|
|
||||||
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""Forward pass.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (tensor): input
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tensor: output
|
|
||||||
"""
|
|
||||||
out = self.relu(x)
|
|
||||||
out = self.conv1(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
out = self.conv2(out)
|
|
||||||
|
|
||||||
return out + x
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureFusionBlock(nn.Module):
|
|
||||||
"""Feature fusion block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, features):
|
|
||||||
"""Init.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features (int): number of features
|
|
||||||
"""
|
|
||||||
super(FeatureFusionBlock, self).__init__()
|
|
||||||
|
|
||||||
self.resConfUnit1 = ResidualConvUnit(features)
|
|
||||||
self.resConfUnit2 = ResidualConvUnit(features)
|
|
||||||
|
|
||||||
def forward(self, *xs):
|
|
||||||
"""Forward pass.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tensor: output
|
|
||||||
"""
|
|
||||||
output = xs[0]
|
|
||||||
|
|
||||||
if len(xs) == 2:
|
|
||||||
output += self.resConfUnit1(xs[1])
|
|
||||||
|
|
||||||
output = self.resConfUnit2(output)
|
|
||||||
|
|
||||||
output = nn.functional.interpolate(
|
|
||||||
output, scale_factor=2, mode="bilinear", align_corners=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualConvUnit_custom(nn.Module):
|
|
||||||
"""Residual convolution module.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, features, activation, bn):
|
|
||||||
"""Init.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features (int): number of features
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.bn = bn
|
|
||||||
|
|
||||||
self.groups=1
|
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(
|
|
||||||
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
|
||||||
)
|
|
||||||
|
|
||||||
self.conv2 = nn.Conv2d(
|
|
||||||
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.bn==True:
|
|
||||||
self.bn1 = nn.BatchNorm2d(features)
|
|
||||||
self.bn2 = nn.BatchNorm2d(features)
|
|
||||||
|
|
||||||
self.activation = activation
|
|
||||||
|
|
||||||
self.skip_add = nn.quantized.FloatFunctional()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""Forward pass.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (tensor): input
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tensor: output
|
|
||||||
"""
|
|
||||||
|
|
||||||
out = self.activation(x)
|
|
||||||
out = self.conv1(out)
|
|
||||||
if self.bn==True:
|
|
||||||
out = self.bn1(out)
|
|
||||||
|
|
||||||
out = self.activation(out)
|
|
||||||
out = self.conv2(out)
|
|
||||||
if self.bn==True:
|
|
||||||
out = self.bn2(out)
|
|
||||||
|
|
||||||
if self.groups > 1:
|
|
||||||
out = self.conv_merge(out)
|
|
||||||
|
|
||||||
return self.skip_add.add(out, x)
|
|
||||||
|
|
||||||
# return out + x
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureFusionBlock_custom(nn.Module):
|
|
||||||
"""Feature fusion block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
|
|
||||||
"""Init.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features (int): number of features
|
|
||||||
"""
|
|
||||||
super(FeatureFusionBlock_custom, self).__init__()
|
|
||||||
|
|
||||||
self.deconv = deconv
|
|
||||||
self.align_corners = align_corners
|
|
||||||
|
|
||||||
self.groups=1
|
|
||||||
|
|
||||||
self.expand = expand
|
|
||||||
out_features = features
|
|
||||||
if self.expand==True:
|
|
||||||
out_features = features//2
|
|
||||||
|
|
||||||
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
|
||||||
|
|
||||||
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
|
||||||
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
|
||||||
|
|
||||||
self.skip_add = nn.quantized.FloatFunctional()
|
|
||||||
|
|
||||||
def forward(self, *xs):
|
|
||||||
"""Forward pass.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tensor: output
|
|
||||||
"""
|
|
||||||
output = xs[0]
|
|
||||||
|
|
||||||
if len(xs) == 2:
|
|
||||||
res = self.resConfUnit1(xs[1])
|
|
||||||
output = self.skip_add.add(output, res)
|
|
||||||
# output += res
|
|
||||||
|
|
||||||
output = self.resConfUnit2(output)
|
|
||||||
|
|
||||||
output = nn.functional.interpolate(
|
|
||||||
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
|
||||||
)
|
|
||||||
|
|
||||||
output = self.out_conv(output)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
@ -1,109 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from .base_model import BaseModel
|
|
||||||
from .blocks import (
|
|
||||||
FeatureFusionBlock,
|
|
||||||
FeatureFusionBlock_custom,
|
|
||||||
Interpolate,
|
|
||||||
_make_encoder,
|
|
||||||
forward_vit,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_fusion_block(features, use_bn):
|
|
||||||
return FeatureFusionBlock_custom(
|
|
||||||
features,
|
|
||||||
nn.ReLU(False),
|
|
||||||
deconv=False,
|
|
||||||
bn=use_bn,
|
|
||||||
expand=False,
|
|
||||||
align_corners=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DPT(BaseModel):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
head,
|
|
||||||
features=256,
|
|
||||||
backbone="vitb_rn50_384",
|
|
||||||
readout="project",
|
|
||||||
channels_last=False,
|
|
||||||
use_bn=False,
|
|
||||||
):
|
|
||||||
|
|
||||||
super(DPT, self).__init__()
|
|
||||||
|
|
||||||
self.channels_last = channels_last
|
|
||||||
|
|
||||||
hooks = {
|
|
||||||
"vitb_rn50_384": [0, 1, 8, 11],
|
|
||||||
"vitb16_384": [2, 5, 8, 11],
|
|
||||||
"vitl16_384": [5, 11, 17, 23],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Instantiate backbone and reassemble blocks
|
|
||||||
self.pretrained, self.scratch = _make_encoder(
|
|
||||||
backbone,
|
|
||||||
features,
|
|
||||||
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
|
||||||
groups=1,
|
|
||||||
expand=False,
|
|
||||||
exportable=False,
|
|
||||||
hooks=hooks[backbone],
|
|
||||||
use_readout=readout,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
|
||||||
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
|
||||||
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
|
||||||
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
|
||||||
|
|
||||||
self.scratch.output_conv = head
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.channels_last == True:
|
|
||||||
x.contiguous(memory_format=torch.channels_last)
|
|
||||||
|
|
||||||
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
|
||||||
|
|
||||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
|
||||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
|
||||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
|
||||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
|
||||||
|
|
||||||
path_4 = self.scratch.refinenet4(layer_4_rn)
|
|
||||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
|
||||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
|
||||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
|
||||||
|
|
||||||
out = self.scratch.output_conv(path_1)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class DPTDepthModel(DPT):
|
|
||||||
def __init__(self, path=None, non_negative=True, **kwargs):
|
|
||||||
features = kwargs["features"] if "features" in kwargs else 256
|
|
||||||
|
|
||||||
head = nn.Sequential(
|
|
||||||
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
|
||||||
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
|
||||||
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
|
||||||
nn.ReLU(True),
|
|
||||||
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
|
||||||
nn.ReLU(True) if non_negative else nn.Identity(),
|
|
||||||
nn.Identity(),
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(head, **kwargs)
|
|
||||||
|
|
||||||
if path is not None:
|
|
||||||
self.load(path)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return super().forward(x).squeeze(dim=1)
|
|
||||||
|
|
||||||
@ -1,76 +0,0 @@
|
|||||||
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
|
||||||
This file contains code that is adapted from
|
|
||||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from .base_model import BaseModel
|
|
||||||
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
|
||||||
|
|
||||||
|
|
||||||
class MidasNet(BaseModel):
|
|
||||||
"""Network for monocular depth estimation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, path=None, features=256, non_negative=True):
|
|
||||||
"""Init.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (str, optional): Path to saved model. Defaults to None.
|
|
||||||
features (int, optional): Number of features. Defaults to 256.
|
|
||||||
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
|
||||||
"""
|
|
||||||
print("Loading weights: ", path)
|
|
||||||
|
|
||||||
super(MidasNet, self).__init__()
|
|
||||||
|
|
||||||
use_pretrained = False if path is None else True
|
|
||||||
|
|
||||||
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
|
|
||||||
|
|
||||||
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
|
||||||
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
|
||||||
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
|
||||||
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
|
||||||
|
|
||||||
self.scratch.output_conv = nn.Sequential(
|
|
||||||
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
|
||||||
Interpolate(scale_factor=2, mode="bilinear"),
|
|
||||||
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
|
||||||
nn.ReLU(True),
|
|
||||||
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
|
||||||
nn.ReLU(True) if non_negative else nn.Identity(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if path:
|
|
||||||
self.load(path)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""Forward pass.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (tensor): input data (image)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tensor: depth
|
|
||||||
"""
|
|
||||||
|
|
||||||
layer_1 = self.pretrained.layer1(x)
|
|
||||||
layer_2 = self.pretrained.layer2(layer_1)
|
|
||||||
layer_3 = self.pretrained.layer3(layer_2)
|
|
||||||
layer_4 = self.pretrained.layer4(layer_3)
|
|
||||||
|
|
||||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
|
||||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
|
||||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
|
||||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
|
||||||
|
|
||||||
path_4 = self.scratch.refinenet4(layer_4_rn)
|
|
||||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
|
||||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
|
||||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
|
||||||
|
|
||||||
out = self.scratch.output_conv(path_1)
|
|
||||||
|
|
||||||
return torch.squeeze(out, dim=1)
|
|
||||||
@ -1,128 +0,0 @@
|
|||||||
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
|
||||||
This file contains code that is adapted from
|
|
||||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from .base_model import BaseModel
|
|
||||||
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
|
||||||
|
|
||||||
|
|
||||||
class MidasNet_small(BaseModel):
|
|
||||||
"""Network for monocular depth estimation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
|
|
||||||
blocks={'expand': True}):
|
|
||||||
"""Init.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (str, optional): Path to saved model. Defaults to None.
|
|
||||||
features (int, optional): Number of features. Defaults to 256.
|
|
||||||
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
|
||||||
"""
|
|
||||||
print("Loading weights: ", path)
|
|
||||||
|
|
||||||
super(MidasNet_small, self).__init__()
|
|
||||||
|
|
||||||
use_pretrained = False if path else True
|
|
||||||
|
|
||||||
self.channels_last = channels_last
|
|
||||||
self.blocks = blocks
|
|
||||||
self.backbone = backbone
|
|
||||||
|
|
||||||
self.groups = 1
|
|
||||||
|
|
||||||
features1=features
|
|
||||||
features2=features
|
|
||||||
features3=features
|
|
||||||
features4=features
|
|
||||||
self.expand = False
|
|
||||||
if "expand" in self.blocks and self.blocks['expand'] == True:
|
|
||||||
self.expand = True
|
|
||||||
features1=features
|
|
||||||
features2=features*2
|
|
||||||
features3=features*4
|
|
||||||
features4=features*8
|
|
||||||
|
|
||||||
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
|
|
||||||
|
|
||||||
self.scratch.activation = nn.ReLU(False)
|
|
||||||
|
|
||||||
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
|
||||||
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
|
||||||
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
|
||||||
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
|
|
||||||
|
|
||||||
|
|
||||||
self.scratch.output_conv = nn.Sequential(
|
|
||||||
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
|
|
||||||
Interpolate(scale_factor=2, mode="bilinear"),
|
|
||||||
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
|
|
||||||
self.scratch.activation,
|
|
||||||
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
|
||||||
nn.ReLU(True) if non_negative else nn.Identity(),
|
|
||||||
nn.Identity(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if path:
|
|
||||||
self.load(path)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""Forward pass.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (tensor): input data (image)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tensor: depth
|
|
||||||
"""
|
|
||||||
if self.channels_last==True:
|
|
||||||
print("self.channels_last = ", self.channels_last)
|
|
||||||
x.contiguous(memory_format=torch.channels_last)
|
|
||||||
|
|
||||||
|
|
||||||
layer_1 = self.pretrained.layer1(x)
|
|
||||||
layer_2 = self.pretrained.layer2(layer_1)
|
|
||||||
layer_3 = self.pretrained.layer3(layer_2)
|
|
||||||
layer_4 = self.pretrained.layer4(layer_3)
|
|
||||||
|
|
||||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
|
||||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
|
||||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
|
||||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
|
||||||
|
|
||||||
|
|
||||||
path_4 = self.scratch.refinenet4(layer_4_rn)
|
|
||||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
|
||||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
|
||||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
|
||||||
|
|
||||||
out = self.scratch.output_conv(path_1)
|
|
||||||
|
|
||||||
return torch.squeeze(out, dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def fuse_model(m):
|
|
||||||
prev_previous_type = nn.Identity()
|
|
||||||
prev_previous_name = ''
|
|
||||||
previous_type = nn.Identity()
|
|
||||||
previous_name = ''
|
|
||||||
for name, module in m.named_modules():
|
|
||||||
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
|
|
||||||
# print("FUSED ", prev_previous_name, previous_name, name)
|
|
||||||
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
|
|
||||||
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
|
||||||
# print("FUSED ", prev_previous_name, previous_name)
|
|
||||||
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
|
|
||||||
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
|
||||||
# print("FUSED ", previous_name, name)
|
|
||||||
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
|
||||||
|
|
||||||
prev_previous_type = previous_type
|
|
||||||
prev_previous_name = previous_name
|
|
||||||
previous_type = type(module)
|
|
||||||
previous_name = name
|
|
||||||
@ -1,234 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
import math
|
|
||||||
|
|
||||||
|
|
||||||
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
|
||||||
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sample (dict): sample
|
|
||||||
size (tuple): image size
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: new size
|
|
||||||
"""
|
|
||||||
shape = list(sample["disparity"].shape)
|
|
||||||
|
|
||||||
if shape[0] >= size[0] and shape[1] >= size[1]:
|
|
||||||
return sample
|
|
||||||
|
|
||||||
scale = [0, 0]
|
|
||||||
scale[0] = size[0] / shape[0]
|
|
||||||
scale[1] = size[1] / shape[1]
|
|
||||||
|
|
||||||
scale = max(scale)
|
|
||||||
|
|
||||||
shape[0] = math.ceil(scale * shape[0])
|
|
||||||
shape[1] = math.ceil(scale * shape[1])
|
|
||||||
|
|
||||||
# resize
|
|
||||||
sample["image"] = cv2.resize(
|
|
||||||
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
|
||||||
)
|
|
||||||
|
|
||||||
sample["disparity"] = cv2.resize(
|
|
||||||
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
|
||||||
)
|
|
||||||
sample["mask"] = cv2.resize(
|
|
||||||
sample["mask"].astype(np.float32),
|
|
||||||
tuple(shape[::-1]),
|
|
||||||
interpolation=cv2.INTER_NEAREST,
|
|
||||||
)
|
|
||||||
sample["mask"] = sample["mask"].astype(bool)
|
|
||||||
|
|
||||||
return tuple(shape)
|
|
||||||
|
|
||||||
|
|
||||||
class Resize(object):
|
|
||||||
"""Resize sample to given size (width, height).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
resize_target=True,
|
|
||||||
keep_aspect_ratio=False,
|
|
||||||
ensure_multiple_of=1,
|
|
||||||
resize_method="lower_bound",
|
|
||||||
image_interpolation_method=cv2.INTER_AREA,
|
|
||||||
):
|
|
||||||
"""Init.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
width (int): desired output width
|
|
||||||
height (int): desired output height
|
|
||||||
resize_target (bool, optional):
|
|
||||||
True: Resize the full sample (image, mask, target).
|
|
||||||
False: Resize image only.
|
|
||||||
Defaults to True.
|
|
||||||
keep_aspect_ratio (bool, optional):
|
|
||||||
True: Keep the aspect ratio of the input sample.
|
|
||||||
Output sample might not have the given width and height, and
|
|
||||||
resize behaviour depends on the parameter 'resize_method'.
|
|
||||||
Defaults to False.
|
|
||||||
ensure_multiple_of (int, optional):
|
|
||||||
Output width and height is constrained to be multiple of this parameter.
|
|
||||||
Defaults to 1.
|
|
||||||
resize_method (str, optional):
|
|
||||||
"lower_bound": Output will be at least as large as the given size.
|
|
||||||
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
|
||||||
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
|
||||||
Defaults to "lower_bound".
|
|
||||||
"""
|
|
||||||
self.__width = width
|
|
||||||
self.__height = height
|
|
||||||
|
|
||||||
self.__resize_target = resize_target
|
|
||||||
self.__keep_aspect_ratio = keep_aspect_ratio
|
|
||||||
self.__multiple_of = ensure_multiple_of
|
|
||||||
self.__resize_method = resize_method
|
|
||||||
self.__image_interpolation_method = image_interpolation_method
|
|
||||||
|
|
||||||
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
|
||||||
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
|
||||||
|
|
||||||
if max_val is not None and y > max_val:
|
|
||||||
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
|
||||||
|
|
||||||
if y < min_val:
|
|
||||||
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
|
||||||
|
|
||||||
return y
|
|
||||||
|
|
||||||
def get_size(self, width, height):
|
|
||||||
# determine new height and width
|
|
||||||
scale_height = self.__height / height
|
|
||||||
scale_width = self.__width / width
|
|
||||||
|
|
||||||
if self.__keep_aspect_ratio:
|
|
||||||
if self.__resize_method == "lower_bound":
|
|
||||||
# scale such that output size is lower bound
|
|
||||||
if scale_width > scale_height:
|
|
||||||
# fit width
|
|
||||||
scale_height = scale_width
|
|
||||||
else:
|
|
||||||
# fit height
|
|
||||||
scale_width = scale_height
|
|
||||||
elif self.__resize_method == "upper_bound":
|
|
||||||
# scale such that output size is upper bound
|
|
||||||
if scale_width < scale_height:
|
|
||||||
# fit width
|
|
||||||
scale_height = scale_width
|
|
||||||
else:
|
|
||||||
# fit height
|
|
||||||
scale_width = scale_height
|
|
||||||
elif self.__resize_method == "minimal":
|
|
||||||
# scale as least as possbile
|
|
||||||
if abs(1 - scale_width) < abs(1 - scale_height):
|
|
||||||
# fit width
|
|
||||||
scale_height = scale_width
|
|
||||||
else:
|
|
||||||
# fit height
|
|
||||||
scale_width = scale_height
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"resize_method {self.__resize_method} not implemented"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.__resize_method == "lower_bound":
|
|
||||||
new_height = self.constrain_to_multiple_of(
|
|
||||||
scale_height * height, min_val=self.__height
|
|
||||||
)
|
|
||||||
new_width = self.constrain_to_multiple_of(
|
|
||||||
scale_width * width, min_val=self.__width
|
|
||||||
)
|
|
||||||
elif self.__resize_method == "upper_bound":
|
|
||||||
new_height = self.constrain_to_multiple_of(
|
|
||||||
scale_height * height, max_val=self.__height
|
|
||||||
)
|
|
||||||
new_width = self.constrain_to_multiple_of(
|
|
||||||
scale_width * width, max_val=self.__width
|
|
||||||
)
|
|
||||||
elif self.__resize_method == "minimal":
|
|
||||||
new_height = self.constrain_to_multiple_of(scale_height * height)
|
|
||||||
new_width = self.constrain_to_multiple_of(scale_width * width)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
|
||||||
|
|
||||||
return (new_width, new_height)
|
|
||||||
|
|
||||||
def __call__(self, sample):
|
|
||||||
width, height = self.get_size(
|
|
||||||
sample["image"].shape[1], sample["image"].shape[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
# resize sample
|
|
||||||
sample["image"] = cv2.resize(
|
|
||||||
sample["image"],
|
|
||||||
(width, height),
|
|
||||||
interpolation=self.__image_interpolation_method,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.__resize_target:
|
|
||||||
if "disparity" in sample:
|
|
||||||
sample["disparity"] = cv2.resize(
|
|
||||||
sample["disparity"],
|
|
||||||
(width, height),
|
|
||||||
interpolation=cv2.INTER_NEAREST,
|
|
||||||
)
|
|
||||||
|
|
||||||
if "depth" in sample:
|
|
||||||
sample["depth"] = cv2.resize(
|
|
||||||
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
|
||||||
)
|
|
||||||
|
|
||||||
sample["mask"] = cv2.resize(
|
|
||||||
sample["mask"].astype(np.float32),
|
|
||||||
(width, height),
|
|
||||||
interpolation=cv2.INTER_NEAREST,
|
|
||||||
)
|
|
||||||
sample["mask"] = sample["mask"].astype(bool)
|
|
||||||
|
|
||||||
return sample
|
|
||||||
|
|
||||||
|
|
||||||
class NormalizeImage(object):
|
|
||||||
"""Normlize image by given mean and std.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, mean, std):
|
|
||||||
self.__mean = mean
|
|
||||||
self.__std = std
|
|
||||||
|
|
||||||
def __call__(self, sample):
|
|
||||||
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
|
||||||
|
|
||||||
return sample
|
|
||||||
|
|
||||||
|
|
||||||
class PrepareForNet(object):
|
|
||||||
"""Prepare sample for usage as network input.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __call__(self, sample):
|
|
||||||
image = np.transpose(sample["image"], (2, 0, 1))
|
|
||||||
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
|
||||||
|
|
||||||
if "mask" in sample:
|
|
||||||
sample["mask"] = sample["mask"].astype(np.float32)
|
|
||||||
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
|
||||||
|
|
||||||
if "disparity" in sample:
|
|
||||||
disparity = sample["disparity"].astype(np.float32)
|
|
||||||
sample["disparity"] = np.ascontiguousarray(disparity)
|
|
||||||
|
|
||||||
if "depth" in sample:
|
|
||||||
depth = sample["depth"].astype(np.float32)
|
|
||||||
sample["depth"] = np.ascontiguousarray(depth)
|
|
||||||
|
|
||||||
return sample
|
|
||||||
@ -1,491 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import timm
|
|
||||||
import types
|
|
||||||
import math
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class Slice(nn.Module):
|
|
||||||
def __init__(self, start_index=1):
|
|
||||||
super(Slice, self).__init__()
|
|
||||||
self.start_index = start_index
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x[:, self.start_index :]
|
|
||||||
|
|
||||||
|
|
||||||
class AddReadout(nn.Module):
|
|
||||||
def __init__(self, start_index=1):
|
|
||||||
super(AddReadout, self).__init__()
|
|
||||||
self.start_index = start_index
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.start_index == 2:
|
|
||||||
readout = (x[:, 0] + x[:, 1]) / 2
|
|
||||||
else:
|
|
||||||
readout = x[:, 0]
|
|
||||||
return x[:, self.start_index :] + readout.unsqueeze(1)
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectReadout(nn.Module):
|
|
||||||
def __init__(self, in_features, start_index=1):
|
|
||||||
super(ProjectReadout, self).__init__()
|
|
||||||
self.start_index = start_index
|
|
||||||
|
|
||||||
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
|
||||||
features = torch.cat((x[:, self.start_index :], readout), -1)
|
|
||||||
|
|
||||||
return self.project(features)
|
|
||||||
|
|
||||||
|
|
||||||
class Transpose(nn.Module):
|
|
||||||
def __init__(self, dim0, dim1):
|
|
||||||
super(Transpose, self).__init__()
|
|
||||||
self.dim0 = dim0
|
|
||||||
self.dim1 = dim1
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.transpose(self.dim0, self.dim1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def forward_vit(pretrained, x):
|
|
||||||
b, c, h, w = x.shape
|
|
||||||
|
|
||||||
glob = pretrained.model.forward_flex(x)
|
|
||||||
|
|
||||||
layer_1 = pretrained.activations["1"]
|
|
||||||
layer_2 = pretrained.activations["2"]
|
|
||||||
layer_3 = pretrained.activations["3"]
|
|
||||||
layer_4 = pretrained.activations["4"]
|
|
||||||
|
|
||||||
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
|
||||||
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
|
||||||
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
|
||||||
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
|
||||||
|
|
||||||
unflatten = nn.Sequential(
|
|
||||||
nn.Unflatten(
|
|
||||||
2,
|
|
||||||
torch.Size(
|
|
||||||
[
|
|
||||||
h // pretrained.model.patch_size[1],
|
|
||||||
w // pretrained.model.patch_size[0],
|
|
||||||
]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if layer_1.ndim == 3:
|
|
||||||
layer_1 = unflatten(layer_1)
|
|
||||||
if layer_2.ndim == 3:
|
|
||||||
layer_2 = unflatten(layer_2)
|
|
||||||
if layer_3.ndim == 3:
|
|
||||||
layer_3 = unflatten(layer_3)
|
|
||||||
if layer_4.ndim == 3:
|
|
||||||
layer_4 = unflatten(layer_4)
|
|
||||||
|
|
||||||
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
|
||||||
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
|
||||||
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
|
||||||
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
|
||||||
|
|
||||||
return layer_1, layer_2, layer_3, layer_4
|
|
||||||
|
|
||||||
|
|
||||||
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
|
||||||
posemb_tok, posemb_grid = (
|
|
||||||
posemb[:, : self.start_index],
|
|
||||||
posemb[0, self.start_index :],
|
|
||||||
)
|
|
||||||
|
|
||||||
gs_old = int(math.sqrt(len(posemb_grid)))
|
|
||||||
|
|
||||||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
|
||||||
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
|
||||||
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
|
||||||
|
|
||||||
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
|
||||||
|
|
||||||
return posemb
|
|
||||||
|
|
||||||
|
|
||||||
def forward_flex(self, x):
|
|
||||||
b, c, h, w = x.shape
|
|
||||||
|
|
||||||
pos_embed = self._resize_pos_embed(
|
|
||||||
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
B = x.shape[0]
|
|
||||||
|
|
||||||
if hasattr(self.patch_embed, "backbone"):
|
|
||||||
x = self.patch_embed.backbone(x)
|
|
||||||
if isinstance(x, (list, tuple)):
|
|
||||||
x = x[-1] # last feature if backbone outputs list/tuple of features
|
|
||||||
|
|
||||||
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
|
||||||
|
|
||||||
if getattr(self, "dist_token", None) is not None:
|
|
||||||
cls_tokens = self.cls_token.expand(
|
|
||||||
B, -1, -1
|
|
||||||
) # stole cls_tokens impl from Phil Wang, thanks
|
|
||||||
dist_token = self.dist_token.expand(B, -1, -1)
|
|
||||||
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
|
||||||
else:
|
|
||||||
cls_tokens = self.cls_token.expand(
|
|
||||||
B, -1, -1
|
|
||||||
) # stole cls_tokens impl from Phil Wang, thanks
|
|
||||||
x = torch.cat((cls_tokens, x), dim=1)
|
|
||||||
|
|
||||||
x = x + pos_embed
|
|
||||||
x = self.pos_drop(x)
|
|
||||||
|
|
||||||
for blk in self.blocks:
|
|
||||||
x = blk(x)
|
|
||||||
|
|
||||||
x = self.norm(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
activations = {}
|
|
||||||
|
|
||||||
|
|
||||||
def get_activation(name):
|
|
||||||
def hook(model, input, output):
|
|
||||||
activations[name] = output
|
|
||||||
|
|
||||||
return hook
|
|
||||||
|
|
||||||
|
|
||||||
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
|
||||||
if use_readout == "ignore":
|
|
||||||
readout_oper = [Slice(start_index)] * len(features)
|
|
||||||
elif use_readout == "add":
|
|
||||||
readout_oper = [AddReadout(start_index)] * len(features)
|
|
||||||
elif use_readout == "project":
|
|
||||||
readout_oper = [
|
|
||||||
ProjectReadout(vit_features, start_index) for out_feat in features
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
False
|
|
||||||
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
|
||||||
|
|
||||||
return readout_oper
|
|
||||||
|
|
||||||
|
|
||||||
def _make_vit_b16_backbone(
|
|
||||||
model,
|
|
||||||
features=[96, 192, 384, 768],
|
|
||||||
size=[384, 384],
|
|
||||||
hooks=[2, 5, 8, 11],
|
|
||||||
vit_features=768,
|
|
||||||
use_readout="ignore",
|
|
||||||
start_index=1,
|
|
||||||
):
|
|
||||||
pretrained = nn.Module()
|
|
||||||
|
|
||||||
pretrained.model = model
|
|
||||||
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
|
||||||
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
|
||||||
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
|
||||||
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
|
||||||
|
|
||||||
pretrained.activations = activations
|
|
||||||
|
|
||||||
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
|
||||||
|
|
||||||
# 32, 48, 136, 384
|
|
||||||
pretrained.act_postprocess1 = nn.Sequential(
|
|
||||||
readout_oper[0],
|
|
||||||
Transpose(1, 2),
|
|
||||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=vit_features,
|
|
||||||
out_channels=features[0],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
),
|
|
||||||
nn.ConvTranspose2d(
|
|
||||||
in_channels=features[0],
|
|
||||||
out_channels=features[0],
|
|
||||||
kernel_size=4,
|
|
||||||
stride=4,
|
|
||||||
padding=0,
|
|
||||||
bias=True,
|
|
||||||
dilation=1,
|
|
||||||
groups=1,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained.act_postprocess2 = nn.Sequential(
|
|
||||||
readout_oper[1],
|
|
||||||
Transpose(1, 2),
|
|
||||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=vit_features,
|
|
||||||
out_channels=features[1],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
),
|
|
||||||
nn.ConvTranspose2d(
|
|
||||||
in_channels=features[1],
|
|
||||||
out_channels=features[1],
|
|
||||||
kernel_size=2,
|
|
||||||
stride=2,
|
|
||||||
padding=0,
|
|
||||||
bias=True,
|
|
||||||
dilation=1,
|
|
||||||
groups=1,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained.act_postprocess3 = nn.Sequential(
|
|
||||||
readout_oper[2],
|
|
||||||
Transpose(1, 2),
|
|
||||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=vit_features,
|
|
||||||
out_channels=features[2],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained.act_postprocess4 = nn.Sequential(
|
|
||||||
readout_oper[3],
|
|
||||||
Transpose(1, 2),
|
|
||||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=vit_features,
|
|
||||||
out_channels=features[3],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
),
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=features[3],
|
|
||||||
out_channels=features[3],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2,
|
|
||||||
padding=1,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained.model.start_index = start_index
|
|
||||||
pretrained.model.patch_size = [16, 16]
|
|
||||||
|
|
||||||
# We inject this function into the VisionTransformer instances so that
|
|
||||||
# we can use it with interpolated position embeddings without modifying the library source.
|
|
||||||
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
|
||||||
pretrained.model._resize_pos_embed = types.MethodType(
|
|
||||||
_resize_pos_embed, pretrained.model
|
|
||||||
)
|
|
||||||
|
|
||||||
return pretrained
|
|
||||||
|
|
||||||
|
|
||||||
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
|
|
||||||
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
|
||||||
|
|
||||||
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
|
||||||
return _make_vit_b16_backbone(
|
|
||||||
model,
|
|
||||||
features=[256, 512, 1024, 1024],
|
|
||||||
hooks=hooks,
|
|
||||||
vit_features=1024,
|
|
||||||
use_readout=use_readout,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
|
|
||||||
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
|
||||||
|
|
||||||
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
|
||||||
return _make_vit_b16_backbone(
|
|
||||||
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
|
|
||||||
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
|
||||||
|
|
||||||
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
|
||||||
return _make_vit_b16_backbone(
|
|
||||||
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
|
|
||||||
model = timm.create_model(
|
|
||||||
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
|
||||||
)
|
|
||||||
|
|
||||||
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
|
||||||
return _make_vit_b16_backbone(
|
|
||||||
model,
|
|
||||||
features=[96, 192, 384, 768],
|
|
||||||
hooks=hooks,
|
|
||||||
use_readout=use_readout,
|
|
||||||
start_index=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_vit_b_rn50_backbone(
|
|
||||||
model,
|
|
||||||
features=[256, 512, 768, 768],
|
|
||||||
size=[384, 384],
|
|
||||||
hooks=[0, 1, 8, 11],
|
|
||||||
vit_features=768,
|
|
||||||
use_vit_only=False,
|
|
||||||
use_readout="ignore",
|
|
||||||
start_index=1,
|
|
||||||
):
|
|
||||||
pretrained = nn.Module()
|
|
||||||
|
|
||||||
pretrained.model = model
|
|
||||||
|
|
||||||
if use_vit_only == True:
|
|
||||||
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
|
||||||
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
|
||||||
else:
|
|
||||||
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
|
||||||
get_activation("1")
|
|
||||||
)
|
|
||||||
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
|
||||||
get_activation("2")
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
|
||||||
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
|
||||||
|
|
||||||
pretrained.activations = activations
|
|
||||||
|
|
||||||
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
|
||||||
|
|
||||||
if use_vit_only == True:
|
|
||||||
pretrained.act_postprocess1 = nn.Sequential(
|
|
||||||
readout_oper[0],
|
|
||||||
Transpose(1, 2),
|
|
||||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=vit_features,
|
|
||||||
out_channels=features[0],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
),
|
|
||||||
nn.ConvTranspose2d(
|
|
||||||
in_channels=features[0],
|
|
||||||
out_channels=features[0],
|
|
||||||
kernel_size=4,
|
|
||||||
stride=4,
|
|
||||||
padding=0,
|
|
||||||
bias=True,
|
|
||||||
dilation=1,
|
|
||||||
groups=1,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained.act_postprocess2 = nn.Sequential(
|
|
||||||
readout_oper[1],
|
|
||||||
Transpose(1, 2),
|
|
||||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=vit_features,
|
|
||||||
out_channels=features[1],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
),
|
|
||||||
nn.ConvTranspose2d(
|
|
||||||
in_channels=features[1],
|
|
||||||
out_channels=features[1],
|
|
||||||
kernel_size=2,
|
|
||||||
stride=2,
|
|
||||||
padding=0,
|
|
||||||
bias=True,
|
|
||||||
dilation=1,
|
|
||||||
groups=1,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pretrained.act_postprocess1 = nn.Sequential(
|
|
||||||
nn.Identity(), nn.Identity(), nn.Identity()
|
|
||||||
)
|
|
||||||
pretrained.act_postprocess2 = nn.Sequential(
|
|
||||||
nn.Identity(), nn.Identity(), nn.Identity()
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained.act_postprocess3 = nn.Sequential(
|
|
||||||
readout_oper[2],
|
|
||||||
Transpose(1, 2),
|
|
||||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=vit_features,
|
|
||||||
out_channels=features[2],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained.act_postprocess4 = nn.Sequential(
|
|
||||||
readout_oper[3],
|
|
||||||
Transpose(1, 2),
|
|
||||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=vit_features,
|
|
||||||
out_channels=features[3],
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
),
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels=features[3],
|
|
||||||
out_channels=features[3],
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2,
|
|
||||||
padding=1,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained.model.start_index = start_index
|
|
||||||
pretrained.model.patch_size = [16, 16]
|
|
||||||
|
|
||||||
# We inject this function into the VisionTransformer instances so that
|
|
||||||
# we can use it with interpolated position embeddings without modifying the library source.
|
|
||||||
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
|
||||||
|
|
||||||
# We inject this function into the VisionTransformer instances so that
|
|
||||||
# we can use it with interpolated position embeddings without modifying the library source.
|
|
||||||
pretrained.model._resize_pos_embed = types.MethodType(
|
|
||||||
_resize_pos_embed, pretrained.model
|
|
||||||
)
|
|
||||||
|
|
||||||
return pretrained
|
|
||||||
|
|
||||||
|
|
||||||
def _make_pretrained_vitb_rn50_384(
|
|
||||||
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
|
|
||||||
):
|
|
||||||
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
|
||||||
|
|
||||||
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
|
||||||
return _make_vit_b_rn50_backbone(
|
|
||||||
model,
|
|
||||||
features=[256, 512, 768, 768],
|
|
||||||
size=[384, 384],
|
|
||||||
hooks=hooks,
|
|
||||||
use_vit_only=use_vit_only,
|
|
||||||
use_readout=use_readout,
|
|
||||||
)
|
|
||||||
@ -1,189 +0,0 @@
|
|||||||
"""Utils for monoDepth."""
|
|
||||||
import sys
|
|
||||||
import re
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def read_pfm(path):
|
|
||||||
"""Read pfm file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (str): path to file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (data, scale)
|
|
||||||
"""
|
|
||||||
with open(path, "rb") as file:
|
|
||||||
|
|
||||||
color = None
|
|
||||||
width = None
|
|
||||||
height = None
|
|
||||||
scale = None
|
|
||||||
endian = None
|
|
||||||
|
|
||||||
header = file.readline().rstrip()
|
|
||||||
if header.decode("ascii") == "PF":
|
|
||||||
color = True
|
|
||||||
elif header.decode("ascii") == "Pf":
|
|
||||||
color = False
|
|
||||||
else:
|
|
||||||
raise Exception("Not a PFM file: " + path)
|
|
||||||
|
|
||||||
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
|
|
||||||
if dim_match:
|
|
||||||
width, height = list(map(int, dim_match.groups()))
|
|
||||||
else:
|
|
||||||
raise Exception("Malformed PFM header.")
|
|
||||||
|
|
||||||
scale = float(file.readline().decode("ascii").rstrip())
|
|
||||||
if scale < 0:
|
|
||||||
# little-endian
|
|
||||||
endian = "<"
|
|
||||||
scale = -scale
|
|
||||||
else:
|
|
||||||
# big-endian
|
|
||||||
endian = ">"
|
|
||||||
|
|
||||||
data = np.fromfile(file, endian + "f")
|
|
||||||
shape = (height, width, 3) if color else (height, width)
|
|
||||||
|
|
||||||
data = np.reshape(data, shape)
|
|
||||||
data = np.flipud(data)
|
|
||||||
|
|
||||||
return data, scale
|
|
||||||
|
|
||||||
|
|
||||||
def write_pfm(path, image, scale=1):
|
|
||||||
"""Write pfm file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (str): pathto file
|
|
||||||
image (array): data
|
|
||||||
scale (int, optional): Scale. Defaults to 1.
|
|
||||||
"""
|
|
||||||
|
|
||||||
with open(path, "wb") as file:
|
|
||||||
color = None
|
|
||||||
|
|
||||||
if image.dtype.name != "float32":
|
|
||||||
raise Exception("Image dtype must be float32.")
|
|
||||||
|
|
||||||
image = np.flipud(image)
|
|
||||||
|
|
||||||
if len(image.shape) == 3 and image.shape[2] == 3: # color image
|
|
||||||
color = True
|
|
||||||
elif (
|
|
||||||
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
|
|
||||||
): # greyscale
|
|
||||||
color = False
|
|
||||||
else:
|
|
||||||
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
|
|
||||||
|
|
||||||
file.write("PF\n" if color else "Pf\n".encode())
|
|
||||||
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
|
|
||||||
|
|
||||||
endian = image.dtype.byteorder
|
|
||||||
|
|
||||||
if endian == "<" or endian == "=" and sys.byteorder == "little":
|
|
||||||
scale = -scale
|
|
||||||
|
|
||||||
file.write("%f\n".encode() % scale)
|
|
||||||
|
|
||||||
image.tofile(file)
|
|
||||||
|
|
||||||
|
|
||||||
def read_image(path):
|
|
||||||
"""Read image and output RGB image (0-1).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (str): path to file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
array: RGB image (0-1)
|
|
||||||
"""
|
|
||||||
img = cv2.imread(path)
|
|
||||||
|
|
||||||
if img.ndim == 2:
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
|
||||||
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def resize_image(img):
|
|
||||||
"""Resize image and make it fit for network.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
img (array): image
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tensor: data ready for network
|
|
||||||
"""
|
|
||||||
height_orig = img.shape[0]
|
|
||||||
width_orig = img.shape[1]
|
|
||||||
|
|
||||||
if width_orig > height_orig:
|
|
||||||
scale = width_orig / 384
|
|
||||||
else:
|
|
||||||
scale = height_orig / 384
|
|
||||||
|
|
||||||
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
|
|
||||||
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
|
|
||||||
|
|
||||||
img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
|
|
||||||
|
|
||||||
img_resized = (
|
|
||||||
torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
|
|
||||||
)
|
|
||||||
img_resized = img_resized.unsqueeze(0)
|
|
||||||
|
|
||||||
return img_resized
|
|
||||||
|
|
||||||
|
|
||||||
def resize_depth(depth, width, height):
|
|
||||||
"""Resize depth map and bring to CPU (numpy).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
depth (tensor): depth
|
|
||||||
width (int): image width
|
|
||||||
height (int): image height
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
array: processed depth
|
|
||||||
"""
|
|
||||||
depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
|
|
||||||
|
|
||||||
depth_resized = cv2.resize(
|
|
||||||
depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
|
|
||||||
)
|
|
||||||
|
|
||||||
return depth_resized
|
|
||||||
|
|
||||||
def write_depth(path, depth, bits=1):
|
|
||||||
"""Write depth map to pfm and png file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (str): filepath without extension
|
|
||||||
depth (array): depth
|
|
||||||
"""
|
|
||||||
write_pfm(path + ".pfm", depth.astype(np.float32))
|
|
||||||
|
|
||||||
depth_min = depth.min()
|
|
||||||
depth_max = depth.max()
|
|
||||||
|
|
||||||
max_val = (2**(8*bits))-1
|
|
||||||
|
|
||||||
if depth_max - depth_min > np.finfo("float").eps:
|
|
||||||
out = max_val * (depth - depth_min) / (depth_max - depth_min)
|
|
||||||
else:
|
|
||||||
out = np.zeros(depth.shape, dtype=depth.type)
|
|
||||||
|
|
||||||
if bits == 1:
|
|
||||||
cv2.imwrite(path + ".png", out.astype("uint8"))
|
|
||||||
elif bits == 2:
|
|
||||||
cv2.imwrite(path + ".png", out.astype("uint16"))
|
|
||||||
|
|
||||||
return
|
|
||||||
@ -151,7 +151,7 @@ if args.lowvram:
|
|||||||
lowvram_available = True
|
lowvram_available = True
|
||||||
elif args.novram:
|
elif args.novram:
|
||||||
set_vram_to = VRAMState.NO_VRAM
|
set_vram_to = VRAMState.NO_VRAM
|
||||||
elif args.highvram:
|
elif args.highvram or args.gpu_only:
|
||||||
vram_state = VRAMState.HIGH_VRAM
|
vram_state = VRAMState.HIGH_VRAM
|
||||||
|
|
||||||
FORCE_FP32 = False
|
FORCE_FP32 = False
|
||||||
@ -307,6 +307,12 @@ def unload_if_low_vram(model):
|
|||||||
return model.cpu()
|
return model.cpu()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def text_encoder_device():
|
||||||
|
if args.gpu_only:
|
||||||
|
return get_torch_device()
|
||||||
|
else:
|
||||||
|
return torch.device("cpu")
|
||||||
|
|
||||||
def get_autocast_device(dev):
|
def get_autocast_device(dev):
|
||||||
if hasattr(dev, 'type'):
|
if hasattr(dev, 'type'):
|
||||||
return dev.type
|
return dev.type
|
||||||
|
|||||||
32
comfy/ops.py
Normal file
32
comfy/ops.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import torch
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
class Linear(torch.nn.Module):
|
||||||
|
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||||
|
device=None, dtype=None) -> None:
|
||||||
|
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||||
|
super().__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
|
||||||
|
if bias:
|
||||||
|
self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
|
||||||
|
else:
|
||||||
|
self.register_parameter('bias', None)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return torch.nn.functional.linear(input, self.weight, self.bias)
|
||||||
|
|
||||||
|
class Conv2d(torch.nn.Conv2d):
|
||||||
|
def reset_parameters(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way
|
||||||
|
old_torch_nn_linear = torch.nn.Linear
|
||||||
|
torch.nn.Linear = Linear
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
torch.nn.Linear = old_torch_nn_linear
|
||||||
@ -273,7 +273,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
max_total_area = model_management.maximum_batch_area()
|
max_total_area = model_management.maximum_batch_area()
|
||||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
|
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
|
||||||
if "sampler_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options:
|
||||||
return model_options["sampler_cfg_function"](cond, uncond, cond_scale)
|
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
|
||||||
|
return model_options["sampler_cfg_function"](args)
|
||||||
else:
|
else:
|
||||||
return uncond + (cond - uncond) * cond_scale
|
return uncond + (cond - uncond) * cond_scale
|
||||||
|
|
||||||
@ -649,7 +650,10 @@ class KSampler:
|
|||||||
self.model_k.latent_image = latent_image
|
self.model_k.latent_image = latent_image
|
||||||
self.model_k.noise = noise
|
self.model_k.noise = noise
|
||||||
|
|
||||||
noise = noise * sigmas[0]
|
if max_denoise:
|
||||||
|
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||||
|
else:
|
||||||
|
noise = noise * sigmas[0]
|
||||||
|
|
||||||
k_callback = None
|
k_callback = None
|
||||||
total_steps = len(sigmas) - 1
|
total_steps = len(sigmas) - 1
|
||||||
|
|||||||
104
comfy/sd.py
104
comfy/sd.py
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
|
import inspect
|
||||||
|
|
||||||
from . import sd1_clip
|
from . import sd1_clip
|
||||||
from . import sd2_clip
|
from . import sd2_clip
|
||||||
@ -85,7 +86,7 @@ LORA_UNET_MAP_RESNET = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def load_lora(path, to_load):
|
def load_lora(path, to_load):
|
||||||
lora = utils.load_torch_file(path)
|
lora = utils.load_torch_file(path, safe_load=True)
|
||||||
patch_dict = {}
|
patch_dict = {}
|
||||||
loaded_keys = set()
|
loaded_keys = set()
|
||||||
for x in to_load:
|
for x in to_load:
|
||||||
@ -301,20 +302,24 @@ class ModelPatcher:
|
|||||||
t = model_sd[k]
|
t = model_sd[k]
|
||||||
size += t.nelement() * t.element_size()
|
size += t.nelement() * t.element_size()
|
||||||
self.size = size
|
self.size = size
|
||||||
|
self.model_keys = set(model_sd.keys())
|
||||||
return size
|
return size
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = ModelPatcher(self.model, self.size)
|
n = ModelPatcher(self.model, self.size)
|
||||||
n.patches = self.patches[:]
|
n.patches = self.patches[:]
|
||||||
n.model_options = copy.deepcopy(self.model_options)
|
n.model_options = copy.deepcopy(self.model_options)
|
||||||
|
n.model_keys = self.model_keys
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def set_model_tomesd(self, ratio):
|
def set_model_tomesd(self, ratio):
|
||||||
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
|
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
|
||||||
|
|
||||||
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
||||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||||
|
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||||
|
else:
|
||||||
|
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||||
|
|
||||||
def set_model_patch(self, patch, name):
|
def set_model_patch(self, patch, name):
|
||||||
to = self.model_options["transformer_options"]
|
to = self.model_options["transformer_options"]
|
||||||
@ -328,6 +333,9 @@ class ModelPatcher:
|
|||||||
def set_model_attn2_patch(self, patch):
|
def set_model_attn2_patch(self, patch):
|
||||||
self.set_model_patch(patch, "attn2_patch")
|
self.set_model_patch(patch, "attn2_patch")
|
||||||
|
|
||||||
|
def set_model_attn2_output_patch(self, patch):
|
||||||
|
self.set_model_patch(patch, "attn2_output_patch")
|
||||||
|
|
||||||
def model_patches_to(self, device):
|
def model_patches_to(self, device):
|
||||||
to = self.model_options["transformer_options"]
|
to = self.model_options["transformer_options"]
|
||||||
if "patches" in to:
|
if "patches" in to:
|
||||||
@ -341,17 +349,25 @@ class ModelPatcher:
|
|||||||
def model_dtype(self):
|
def model_dtype(self):
|
||||||
return self.model.get_dtype()
|
return self.model.get_dtype()
|
||||||
|
|
||||||
def add_patches(self, patches, strength=1.0):
|
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
p = {}
|
p = {}
|
||||||
model_sd = self.model.state_dict()
|
|
||||||
for k in patches:
|
for k in patches:
|
||||||
if k in model_sd:
|
if k in self.model_keys:
|
||||||
p[k] = patches[k]
|
p[k] = patches[k]
|
||||||
self.patches += [(strength, p)]
|
self.patches += [(strength_patch, p, strength_model)]
|
||||||
return p.keys()
|
return p.keys()
|
||||||
|
|
||||||
|
def model_state_dict(self, filter_prefix=None):
|
||||||
|
sd = self.model.state_dict()
|
||||||
|
keys = list(sd.keys())
|
||||||
|
if filter_prefix is not None:
|
||||||
|
for k in keys:
|
||||||
|
if not k.startswith(filter_prefix):
|
||||||
|
sd.pop(k)
|
||||||
|
return sd
|
||||||
|
|
||||||
def patch_model(self):
|
def patch_model(self):
|
||||||
model_sd = self.model.state_dict()
|
model_sd = self.model_state_dict()
|
||||||
for p in self.patches:
|
for p in self.patches:
|
||||||
for k in p[1]:
|
for k in p[1]:
|
||||||
v = p[1][k]
|
v = p[1][k]
|
||||||
@ -365,8 +381,14 @@ class ModelPatcher:
|
|||||||
self.backup[key] = weight.clone()
|
self.backup[key] = weight.clone()
|
||||||
|
|
||||||
alpha = p[0]
|
alpha = p[0]
|
||||||
|
strength_model = p[2]
|
||||||
|
|
||||||
if len(v) == 4: #lora/locon
|
if strength_model != 1.0:
|
||||||
|
weight *= strength_model
|
||||||
|
|
||||||
|
if len(v) == 1:
|
||||||
|
weight += alpha * (v[0]).type(weight.dtype).to(weight.device)
|
||||||
|
elif len(v) == 4: #lora/locon
|
||||||
mat1 = v[0]
|
mat1 = v[0]
|
||||||
mat2 = v[1]
|
mat2 = v[1]
|
||||||
if v[2] is not None:
|
if v[2] is not None:
|
||||||
@ -422,7 +444,7 @@ class ModelPatcher:
|
|||||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
||||||
return self.model
|
return self.model
|
||||||
def unpatch_model(self):
|
def unpatch_model(self):
|
||||||
model_sd = self.model.state_dict()
|
model_sd = self.model_state_dict()
|
||||||
keys = list(self.backup.keys())
|
keys = list(self.backup.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
model_sd[k][:] = self.backup[k]
|
model_sd[k][:] = self.backup[k]
|
||||||
@ -464,7 +486,11 @@ class CLIP:
|
|||||||
clip = sd1_clip.SD1ClipModel
|
clip = sd1_clip.SD1ClipModel
|
||||||
tokenizer = sd1_clip.SD1Tokenizer
|
tokenizer = sd1_clip.SD1Tokenizer
|
||||||
|
|
||||||
|
self.device = model_management.text_encoder_device()
|
||||||
|
params["device"] = self.device
|
||||||
self.cond_stage_model = clip(**(params))
|
self.cond_stage_model = clip(**(params))
|
||||||
|
self.cond_stage_model = self.cond_stage_model.to(self.device)
|
||||||
|
|
||||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
||||||
self.patcher = ModelPatcher(self.cond_stage_model)
|
self.patcher = ModelPatcher(self.cond_stage_model)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
@ -544,6 +570,19 @@ class VAE:
|
|||||||
/ 3.0) / 2.0, min=0.0, max=1.0)
|
/ 3.0) / 2.0, min=0.0, max=1.0)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
|
steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||||
|
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||||
|
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||||
|
pbar = utils.ProgressBar(steps)
|
||||||
|
|
||||||
|
encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() * self.scale_factor
|
||||||
|
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||||
|
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||||
|
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||||
|
samples /= 3.0
|
||||||
|
return samples
|
||||||
|
|
||||||
def decode(self, samples_in):
|
def decode(self, samples_in):
|
||||||
model_management.unload_model()
|
model_management.unload_model()
|
||||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||||
@ -574,28 +613,29 @@ class VAE:
|
|||||||
def encode(self, pixel_samples):
|
def encode(self, pixel_samples):
|
||||||
model_management.unload_model()
|
model_management.unload_model()
|
||||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||||
pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
|
pixel_samples = pixel_samples.movedim(-1,1)
|
||||||
samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor
|
try:
|
||||||
|
free_memory = model_management.get_free_memory(self.device)
|
||||||
|
batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
|
||||||
|
batch_number = max(1, batch_number)
|
||||||
|
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
|
||||||
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
|
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
|
||||||
|
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor
|
||||||
|
|
||||||
|
except model_management.OOM_EXCEPTION as e:
|
||||||
|
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
|
samples = self.encode_tiled_(pixel_samples)
|
||||||
|
|
||||||
self.first_stage_model = self.first_stage_model.cpu()
|
self.first_stage_model = self.first_stage_model.cpu()
|
||||||
samples = samples.cpu()
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
model_management.unload_model()
|
model_management.unload_model()
|
||||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||||
pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
|
pixel_samples = pixel_samples.movedim(-1,1)
|
||||||
|
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
||||||
steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
|
||||||
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
|
||||||
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
|
||||||
pbar = utils.ProgressBar(steps)
|
|
||||||
|
|
||||||
samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
|
||||||
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
|
||||||
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
|
||||||
samples /= 3.0
|
|
||||||
self.first_stage_model = self.first_stage_model.cpu()
|
self.first_stage_model = self.first_stage_model.cpu()
|
||||||
samples = samples.cpu()
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||||
@ -708,7 +748,7 @@ class ControlNet:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None):
|
def load_controlnet(ckpt_path, model=None):
|
||||||
controlnet_data = utils.load_torch_file(ckpt_path)
|
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||||
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||||
pth = False
|
pth = False
|
||||||
sd2 = False
|
sd2 = False
|
||||||
@ -910,7 +950,7 @@ class StyleModel:
|
|||||||
|
|
||||||
|
|
||||||
def load_style_model(ckpt_path):
|
def load_style_model(ckpt_path):
|
||||||
model_data = utils.load_torch_file(ckpt_path)
|
model_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||||
keys = model_data.keys()
|
keys = model_data.keys()
|
||||||
if "style_embedding" in keys:
|
if "style_embedding" in keys:
|
||||||
model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
|
model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
|
||||||
@ -921,7 +961,7 @@ def load_style_model(ckpt_path):
|
|||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_path, embedding_directory=None):
|
def load_clip(ckpt_path, embedding_directory=None):
|
||||||
clip_data = utils.load_torch_file(ckpt_path)
|
clip_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||||
config = {}
|
config = {}
|
||||||
if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
|
if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
|
||||||
config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
||||||
@ -932,7 +972,7 @@ def load_clip(ckpt_path, embedding_directory=None):
|
|||||||
return clip
|
return clip
|
||||||
|
|
||||||
def load_gligen(ckpt_path):
|
def load_gligen(ckpt_path):
|
||||||
data = utils.load_torch_file(ckpt_path)
|
data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||||
model = gligen.load_gligen(data)
|
model = gligen.load_gligen(data)
|
||||||
if model_management.should_use_fp16():
|
if model_management.should_use_fp16():
|
||||||
model = model.half()
|
model = model.half()
|
||||||
@ -1097,7 +1137,6 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
|
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
|
||||||
|
|
||||||
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
|
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
|
||||||
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
|
|
||||||
|
|
||||||
unclip_model = False
|
unclip_model = False
|
||||||
inpaint_model = False
|
inpaint_model = False
|
||||||
@ -1107,11 +1146,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
sd_config["embedding_dropout"] = 0.25
|
sd_config["embedding_dropout"] = 0.25
|
||||||
sd_config["conditioning_key"] = 'crossattn-adm'
|
sd_config["conditioning_key"] = 'crossattn-adm'
|
||||||
unclip_model = True
|
unclip_model = True
|
||||||
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
|
|
||||||
elif unet_config["in_channels"] > 4: #inpainting model
|
elif unet_config["in_channels"] > 4: #inpainting model
|
||||||
sd_config["conditioning_key"] = "hybrid"
|
sd_config["conditioning_key"] = "hybrid"
|
||||||
sd_config["finetune_keys"] = None
|
sd_config["finetune_keys"] = None
|
||||||
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
|
||||||
inpaint_model = True
|
inpaint_model = True
|
||||||
else:
|
else:
|
||||||
sd_config["conditioning_key"] = "crossattn"
|
sd_config["conditioning_key"] = "crossattn"
|
||||||
@ -1143,7 +1180,4 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
|
|
||||||
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||||
|
|
||||||
if fp16:
|
|
||||||
model = model.half()
|
|
||||||
|
|
||||||
return (ModelPatcher(model), clip, vae, clipvision)
|
return (ModelPatcher(model), clip, vae, clipvision)
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
|
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
|
||||||
|
import comfy.ops
|
||||||
import torch
|
import torch
|
||||||
import traceback
|
import traceback
|
||||||
import zipfile
|
import zipfile
|
||||||
@ -19,7 +20,7 @@ class ClipTokenWeightEncoder:
|
|||||||
output += [z]
|
output += [z]
|
||||||
if (len(output) == 0):
|
if (len(output) == 0):
|
||||||
return self.encode(self.empty_tokens)
|
return self.encode(self.empty_tokens)
|
||||||
return torch.cat(output, dim=-2)
|
return torch.cat(output, dim=-2).cpu()
|
||||||
|
|
||||||
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||||
@ -38,7 +39,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
if textmodel_json_config is None:
|
if textmodel_json_config is None:
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||||
config = CLIPTextConfig.from_json_file(textmodel_json_config)
|
config = CLIPTextConfig.from_json_file(textmodel_json_config)
|
||||||
self.transformer = CLIPTextModel(config)
|
with comfy.ops.use_comfy_ops():
|
||||||
|
with modeling_utils.no_init_weights():
|
||||||
|
self.transformer = CLIPTextModel(config)
|
||||||
|
|
||||||
self.device = device
|
self.device = device
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import struct
|
import struct
|
||||||
|
import comfy.checkpoint_pickle
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False):
|
def load_torch_file(ckpt, safe_load=False):
|
||||||
if ckpt.lower().endswith(".safetensors"):
|
if ckpt.lower().endswith(".safetensors"):
|
||||||
@ -14,7 +15,7 @@ def load_torch_file(ckpt, safe_load=False):
|
|||||||
if safe_load:
|
if safe_load:
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
|
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
|
||||||
else:
|
else:
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle)
|
||||||
if "global_step" in pl_sd:
|
if "global_step" in pl_sd:
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
if "state_dict" in pl_sd:
|
if "state_dict" in pl_sd:
|
||||||
|
|||||||
@ -68,7 +68,7 @@ def load_hypernetwork_patch(path, strength):
|
|||||||
def __init__(self, hypernet, strength):
|
def __init__(self, hypernet, strength):
|
||||||
self.hypernet = hypernet
|
self.hypernet = hypernet
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
def __call__(self, current_index, q, k, v):
|
def __call__(self, q, k, v, extra_options):
|
||||||
dim = k.shape[-1]
|
dim = k.shape[-1]
|
||||||
if dim in self.hypernet:
|
if dim in self.hypernet:
|
||||||
hn = self.hypernet[dim]
|
hn = self.hypernet[dim]
|
||||||
|
|||||||
55
comfy_extras/nodes_model_merging.py
Normal file
55
comfy_extras/nodes_model_merging.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
|
||||||
|
|
||||||
|
class ModelMergeSimple:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model1": ("MODEL",),
|
||||||
|
"model2": ("MODEL",),
|
||||||
|
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "merge"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/model_merging"
|
||||||
|
|
||||||
|
def merge(self, model1, model2, ratio):
|
||||||
|
m = model1.clone()
|
||||||
|
sd = model2.model_state_dict("diffusion_model.")
|
||||||
|
for k in sd:
|
||||||
|
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
class ModelMergeBlocks:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model1": ("MODEL",),
|
||||||
|
"model2": ("MODEL",),
|
||||||
|
"input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "merge"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/model_merging"
|
||||||
|
|
||||||
|
def merge(self, model1, model2, **kwargs):
|
||||||
|
m = model1.clone()
|
||||||
|
sd = model2.model_state_dict("diffusion_model.")
|
||||||
|
default_ratio = next(iter(kwargs.values()))
|
||||||
|
|
||||||
|
for k in sd:
|
||||||
|
ratio = default_ratio
|
||||||
|
k_unet = k[len("diffusion_model."):]
|
||||||
|
|
||||||
|
for arg in kwargs:
|
||||||
|
if k_unet.startswith(arg):
|
||||||
|
ratio = kwargs[arg]
|
||||||
|
|
||||||
|
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"ModelMergeSimple": ModelMergeSimple,
|
||||||
|
"ModelMergeBlocks": ModelMergeBlocks
|
||||||
|
}
|
||||||
15
execution.py
15
execution.py
@ -310,7 +310,6 @@ class PromptExecutor:
|
|||||||
else:
|
else:
|
||||||
self.server.client_id = None
|
self.server.client_id = None
|
||||||
|
|
||||||
execution_start_time = time.perf_counter()
|
|
||||||
if self.server.client_id is not None:
|
if self.server.client_id is not None:
|
||||||
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
|
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
|
||||||
|
|
||||||
@ -358,12 +357,7 @@ class PromptExecutor:
|
|||||||
for x in executed:
|
for x in executed:
|
||||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||||
self.server.last_node_id = None
|
self.server.last_node_id = None
|
||||||
if self.server.client_id is not None:
|
|
||||||
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
|
|
||||||
|
|
||||||
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
|
|
||||||
gc.collect()
|
|
||||||
comfy.model_management.soft_empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def validate_inputs(prompt, item, validated):
|
def validate_inputs(prompt, item, validated):
|
||||||
@ -728,9 +722,14 @@ class PromptQueue:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_history(self):
|
def get_history(self, prompt_id=None):
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
return copy.deepcopy(self.history)
|
if prompt_id is None:
|
||||||
|
return copy.deepcopy(self.history)
|
||||||
|
elif prompt_id in self.history:
|
||||||
|
return {prompt_id: copy.deepcopy(self.history[prompt_id])}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
def wipe_history(self):
|
def wipe_history(self):
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
|
|||||||
13
main.py
13
main.py
@ -3,6 +3,8 @@ import itertools
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import threading
|
import threading
|
||||||
|
import gc
|
||||||
|
import time
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
@ -28,15 +30,22 @@ import folder_paths
|
|||||||
import server
|
import server
|
||||||
from server import BinaryEventTypes
|
from server import BinaryEventTypes
|
||||||
from nodes import init_custom_nodes
|
from nodes import init_custom_nodes
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
def prompt_worker(q, server):
|
def prompt_worker(q, server):
|
||||||
e = execution.PromptExecutor(server)
|
e = execution.PromptExecutor(server)
|
||||||
while True:
|
while True:
|
||||||
item, item_id = q.get()
|
item, item_id = q.get()
|
||||||
e.execute(item[2], item[1], item[3], item[4])
|
execution_start_time = time.perf_counter()
|
||||||
|
prompt_id = item[1]
|
||||||
|
e.execute(item[2], prompt_id, item[3], item[4])
|
||||||
q.task_done(item_id, e.outputs_ui)
|
q.task_done(item_id, e.outputs_ui)
|
||||||
|
if server.client_id is not None:
|
||||||
|
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
|
||||||
|
|
||||||
|
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
|
||||||
|
gc.collect()
|
||||||
|
comfy.model_management.soft_empty_cache()
|
||||||
|
|
||||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||||
|
|||||||
29
nodes.py
29
nodes.py
@ -756,7 +756,7 @@ class RepeatLatentBatch:
|
|||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
class LatentUpscale:
|
class LatentUpscale:
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
|
||||||
crop_methods = ["disabled", "center"]
|
crop_methods = ["disabled", "center"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -776,7 +776,7 @@ class LatentUpscale:
|
|||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
class LatentUpscaleBy:
|
class LatentUpscaleBy:
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -1172,7 +1172,7 @@ class LoadImageMask:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
class ImageScale:
|
class ImageScale:
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
|
||||||
crop_methods = ["disabled", "center"]
|
crop_methods = ["disabled", "center"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1192,6 +1192,26 @@ class ImageScale:
|
|||||||
s = s.movedim(1,-1)
|
s = s.movedim(1,-1)
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
|
class ImageScaleBy:
|
||||||
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
|
||||||
|
"scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}),}}
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "upscale"
|
||||||
|
|
||||||
|
CATEGORY = "image/upscaling"
|
||||||
|
|
||||||
|
def upscale(self, image, upscale_method, scale_by):
|
||||||
|
samples = image.movedim(-1,1)
|
||||||
|
width = round(samples.shape[3] * scale_by)
|
||||||
|
height = round(samples.shape[2] * scale_by)
|
||||||
|
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
|
||||||
|
s = s.movedim(1,-1)
|
||||||
|
return (s,)
|
||||||
|
|
||||||
class ImageInvert:
|
class ImageInvert:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1290,6 +1310,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"LoadImage": LoadImage,
|
"LoadImage": LoadImage,
|
||||||
"LoadImageMask": LoadImageMask,
|
"LoadImageMask": LoadImageMask,
|
||||||
"ImageScale": ImageScale,
|
"ImageScale": ImageScale,
|
||||||
|
"ImageScaleBy": ImageScaleBy,
|
||||||
"ImageInvert": ImageInvert,
|
"ImageInvert": ImageInvert,
|
||||||
"ImagePadForOutpaint": ImagePadForOutpaint,
|
"ImagePadForOutpaint": ImagePadForOutpaint,
|
||||||
"ConditioningAverage ": ConditioningAverage ,
|
"ConditioningAverage ": ConditioningAverage ,
|
||||||
@ -1371,6 +1392,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"LoadImage": "Load Image",
|
"LoadImage": "Load Image",
|
||||||
"LoadImageMask": "Load Image (as Mask)",
|
"LoadImageMask": "Load Image (as Mask)",
|
||||||
"ImageScale": "Upscale Image",
|
"ImageScale": "Upscale Image",
|
||||||
|
"ImageScaleBy": "Upscale Image By",
|
||||||
"ImageUpscaleWithModel": "Upscale Image (using Model)",
|
"ImageUpscaleWithModel": "Upscale Image (using Model)",
|
||||||
"ImageInvert": "Invert Image",
|
"ImageInvert": "Invert Image",
|
||||||
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
||||||
@ -1437,4 +1459,5 @@ def init_custom_nodes():
|
|||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py"))
|
||||||
load_custom_nodes()
|
load_custom_nodes()
|
||||||
|
|||||||
@ -2,10 +2,11 @@ torch
|
|||||||
torchdiffeq
|
torchdiffeq
|
||||||
torchsde
|
torchsde
|
||||||
einops
|
einops
|
||||||
open-clip-torch
|
|
||||||
transformers>=4.25.1
|
transformers>=4.25.1
|
||||||
safetensors>=0.3.0
|
safetensors>=0.3.0
|
||||||
pytorch_lightning
|
|
||||||
aiohttp
|
aiohttp
|
||||||
accelerate
|
accelerate
|
||||||
pyyaml
|
pyyaml
|
||||||
|
Pillow
|
||||||
|
scipy
|
||||||
|
tqdm
|
||||||
|
|||||||
164
script_examples/websockets_api_example.py
Normal file
164
script_examples/websockets_api_example.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
#This is an example that uses the websockets api to know when a prompt execution is done
|
||||||
|
#Once the prompt execution is done it downloads the images using the /history endpoint
|
||||||
|
|
||||||
|
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
import urllib.request
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
|
server_address = "127.0.0.1:8188"
|
||||||
|
client_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
def queue_prompt(prompt):
|
||||||
|
p = {"prompt": prompt, "client_id": client_id}
|
||||||
|
data = json.dumps(p).encode('utf-8')
|
||||||
|
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
|
||||||
|
return json.loads(urllib.request.urlopen(req).read())
|
||||||
|
|
||||||
|
def get_image(filename, subfolder, folder_type):
|
||||||
|
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||||
|
url_values = urllib.parse.urlencode(data)
|
||||||
|
with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
|
||||||
|
return response.read()
|
||||||
|
|
||||||
|
def get_history(prompt_id):
|
||||||
|
with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
|
||||||
|
return json.loads(response.read())
|
||||||
|
|
||||||
|
def get_images(ws, prompt):
|
||||||
|
prompt_id = queue_prompt(prompt)['prompt_id']
|
||||||
|
output_images = {}
|
||||||
|
while True:
|
||||||
|
out = ws.recv()
|
||||||
|
if isinstance(out, str):
|
||||||
|
message = json.loads(out)
|
||||||
|
if message['type'] == 'executing':
|
||||||
|
data = message['data']
|
||||||
|
if data['node'] is None and data['prompt_id'] == prompt_id:
|
||||||
|
break #Execution is done
|
||||||
|
else:
|
||||||
|
continue #previews are binary data
|
||||||
|
|
||||||
|
history = get_history(prompt_id)[prompt_id]
|
||||||
|
for o in history['outputs']:
|
||||||
|
for node_id in history['outputs']:
|
||||||
|
node_output = history['outputs'][node_id]
|
||||||
|
if 'images' in node_output:
|
||||||
|
images_output = []
|
||||||
|
for image in node_output['images']:
|
||||||
|
image_data = get_image(image['filename'], image['subfolder'], image['type'])
|
||||||
|
images_output.append(image_data)
|
||||||
|
output_images[node_id] = images_output
|
||||||
|
|
||||||
|
return output_images
|
||||||
|
|
||||||
|
prompt_text = """
|
||||||
|
{
|
||||||
|
"3": {
|
||||||
|
"class_type": "KSampler",
|
||||||
|
"inputs": {
|
||||||
|
"cfg": 8,
|
||||||
|
"denoise": 1,
|
||||||
|
"latent_image": [
|
||||||
|
"5",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"model": [
|
||||||
|
"4",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"7",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"6",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"seed": 8566257,
|
||||||
|
"steps": 20
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"class_type": "CheckpointLoaderSimple",
|
||||||
|
"inputs": {
|
||||||
|
"ckpt_name": "v1-5-pruned-emaonly.ckpt"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"class_type": "EmptyLatentImage",
|
||||||
|
"inputs": {
|
||||||
|
"batch_size": 1,
|
||||||
|
"height": 512,
|
||||||
|
"width": 512
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"inputs": {
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"text": "masterpiece best quality girl"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"7": {
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"inputs": {
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"text": "bad hands"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"inputs": {
|
||||||
|
"samples": [
|
||||||
|
"3",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"4",
|
||||||
|
2
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"9": {
|
||||||
|
"class_type": "SaveImage",
|
||||||
|
"inputs": {
|
||||||
|
"filename_prefix": "ComfyUI",
|
||||||
|
"images": [
|
||||||
|
"8",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt = json.loads(prompt_text)
|
||||||
|
#set the text prompt for our positive CLIPTextEncode
|
||||||
|
prompt["6"]["inputs"]["text"] = "masterpiece best quality man"
|
||||||
|
|
||||||
|
#set the seed for our KSampler node
|
||||||
|
prompt["3"]["inputs"]["seed"] = 5
|
||||||
|
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
|
||||||
|
images = get_images(ws, prompt)
|
||||||
|
|
||||||
|
#Commented out code to display the output images:
|
||||||
|
|
||||||
|
# for node_id in images:
|
||||||
|
# for image_data in images[node_id]:
|
||||||
|
# from PIL import Image
|
||||||
|
# import io
|
||||||
|
# image = Image.open(io.BytesIO(image_data))
|
||||||
|
# image.show()
|
||||||
|
|
||||||
18
server.py
18
server.py
@ -30,6 +30,11 @@ import comfy.model_management
|
|||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
PREVIEW_IMAGE = 1
|
PREVIEW_IMAGE = 1
|
||||||
|
|
||||||
|
async def send_socket_catch_exception(function, message):
|
||||||
|
try:
|
||||||
|
await function(message)
|
||||||
|
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
|
||||||
|
print("send error:", err)
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def cache_control(request: web.Request, handler):
|
async def cache_control(request: web.Request, handler):
|
||||||
@ -372,6 +377,11 @@ class PromptServer():
|
|||||||
async def get_history(request):
|
async def get_history(request):
|
||||||
return web.json_response(self.prompt_queue.get_history())
|
return web.json_response(self.prompt_queue.get_history())
|
||||||
|
|
||||||
|
@routes.get("/history/{prompt_id}")
|
||||||
|
async def get_history(request):
|
||||||
|
prompt_id = request.match_info.get("prompt_id", None)
|
||||||
|
return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
|
||||||
|
|
||||||
@routes.get("/queue")
|
@routes.get("/queue")
|
||||||
async def get_queue(request):
|
async def get_queue(request):
|
||||||
queue_info = {}
|
queue_info = {}
|
||||||
@ -482,18 +492,18 @@ class PromptServer():
|
|||||||
|
|
||||||
if sid is None:
|
if sid is None:
|
||||||
for ws in self.sockets.values():
|
for ws in self.sockets.values():
|
||||||
await ws.send_bytes(message)
|
await send_socket_catch_exception(ws.send_bytes, message)
|
||||||
elif sid in self.sockets:
|
elif sid in self.sockets:
|
||||||
await self.sockets[sid].send_bytes(message)
|
await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
|
||||||
|
|
||||||
async def send_json(self, event, data, sid=None):
|
async def send_json(self, event, data, sid=None):
|
||||||
message = {"type": event, "data": data}
|
message = {"type": event, "data": data}
|
||||||
|
|
||||||
if sid is None:
|
if sid is None:
|
||||||
for ws in self.sockets.values():
|
for ws in self.sockets.values():
|
||||||
await ws.send_json(message)
|
await send_socket_catch_exception(ws.send_json, message)
|
||||||
elif sid in self.sockets:
|
elif sid in self.sockets:
|
||||||
await self.sockets[sid].send_json(message)
|
await send_socket_catch_exception(self.sockets[sid].send_json, message)
|
||||||
|
|
||||||
def send_sync(self, event, data, sid=None):
|
def send_sync(self, event, data, sid=None):
|
||||||
self.loop.call_soon_threadsafe(
|
self.loop.call_soon_threadsafe(
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import { app } from "/scripts/app.js";
|
import {app} from "/scripts/app.js";
|
||||||
import { $el } from "/scripts/ui.js";
|
import {$el} from "/scripts/ui.js";
|
||||||
import { api } from "/scripts/api.js";
|
|
||||||
|
|
||||||
// Manage color palettes
|
// Manage color palettes
|
||||||
|
|
||||||
@ -24,6 +23,8 @@ const colorPalettes = {
|
|||||||
"TAESD": "#DCC274", // cheesecake
|
"TAESD": "#DCC274", // cheesecake
|
||||||
},
|
},
|
||||||
"litegraph_base": {
|
"litegraph_base": {
|
||||||
|
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAAGXRFWHRTb2Z0d2FyZQBBZG9iZSBJbWFnZVJlYWR5ccllPAAAAQBJREFUeNrs1rEKwjAUhlETUkj3vP9rdmr1Ysammk2w5wdxuLgcMHyptfawuZX4pJSWZTnfnu/lnIe/jNNxHHGNn//HNbbv+4dr6V+11uF527arU7+u63qfa/bnmh8sWLBgwYJlqRf8MEptXPBXJXa37BSl3ixYsGDBMliwFLyCV/DeLIMFCxYsWLBMwSt4Be/NggXLYMGCBUvBK3iNruC9WbBgwYJlsGApeAWv4L1ZBgsWLFiwYJmCV/AK3psFC5bBggULloJX8BpdwXuzYMGCBctgwVLwCl7Be7MMFixYsGDBsu8FH1FaSmExVfAxBa/gvVmwYMGCZbBg/W4vAQYA5tRF9QYlv/QAAAAASUVORK5CYII=",
|
||||||
|
"CLEAR_BACKGROUND_COLOR": "#222",
|
||||||
"NODE_TITLE_COLOR": "#999",
|
"NODE_TITLE_COLOR": "#999",
|
||||||
"NODE_SELECTED_TITLE_COLOR": "#FFF",
|
"NODE_SELECTED_TITLE_COLOR": "#FFF",
|
||||||
"NODE_TEXT_SIZE": 14,
|
"NODE_TEXT_SIZE": 14,
|
||||||
@ -55,7 +56,9 @@ const colorPalettes = {
|
|||||||
"descrip-text": "#999",
|
"descrip-text": "#999",
|
||||||
"drag-text": "#ccc",
|
"drag-text": "#ccc",
|
||||||
"error-text": "#ff4444",
|
"error-text": "#ff4444",
|
||||||
"border-color": "#4e4e4e"
|
"border-color": "#4e4e4e",
|
||||||
|
"tr-even-bg-color": "#222",
|
||||||
|
"tr-odd-bg-color": "#353535",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -77,6 +80,8 @@ const colorPalettes = {
|
|||||||
"VAE": "#FF7043", // deep orange
|
"VAE": "#FF7043", // deep orange
|
||||||
},
|
},
|
||||||
"litegraph_base": {
|
"litegraph_base": {
|
||||||
|
"BACKGROUND_IMAGE": "data:image/gif;base64,R0lGODlhZABkALMAAAAAAP///+vr6+rq6ujo6Ofn5+bm5uXl5d3d3f///wAAAAAAAAAAAAAAAAAAAAAAACH5BAEAAAkALAAAAABkAGQAAAT/UMhJq7046827HkcoHkYxjgZhnGG6si5LqnIM0/fL4qwwIMAg0CAsEovBIxKhRDaNy2GUOX0KfVFrssrNdpdaqTeKBX+dZ+jYvEaTf+y4W66mC8PUdrE879f9d2mBeoNLfH+IhYBbhIx2jkiHiomQlGKPl4uZe3CaeZifnnijgkESBqipqqusra6vsLGys62SlZO4t7qbuby7CLa+wqGWxL3Gv3jByMOkjc2lw8vOoNSi0czAncXW3Njdx9Pf48/Z4Kbbx+fQ5evZ4u3k1fKR6cn03vHlp7T9/v8A/8Gbp4+gwXoFryXMB2qgwoMMHyKEqA5fxX322FG8tzBcRnMW/zlulPbRncmQGidKjMjyYsOSKEF2FBlJQMCbOHP6c9iSZs+UnGYCdbnSo1CZI5F64kn0p1KnTH02nSoV3dGTV7FFHVqVq1dtWcMmVQZTbNGu72zqXMuW7danVL+6e4t1bEy6MeueBYLXrNO5Ze36jQtWsOG97wIj1vt3St/DjTEORss4nNq2mDP3e7w4r1bFkSET5hy6s2TRlD2/mSxXtSHQhCunXo26NevCpmvD/UU6tuullzULH76q92zdZG/Ltv1a+W+osI/nRmyc+fRi1Xdbh+68+0vv10dH3+77KD/i6IdnX669/frn5Zsjh4/2PXju8+8bzc9/6fj27LFnX11/+IUnXWl7BJfegm79FyB9JOl3oHgSklefgxAC+FmFGpqHIYcCfkhgfCohSKKJVo044YUMttggiBkmp6KFXw1oII24oYhjiDByaKOOHcp3Y5BD/njikSkO+eBREQAAOw==",
|
||||||
|
"CLEAR_BACKGROUND_COLOR": "lightgray",
|
||||||
"NODE_TITLE_COLOR": "#222",
|
"NODE_TITLE_COLOR": "#222",
|
||||||
"NODE_SELECTED_TITLE_COLOR": "#000",
|
"NODE_SELECTED_TITLE_COLOR": "#000",
|
||||||
"NODE_TEXT_SIZE": 14,
|
"NODE_TEXT_SIZE": 14,
|
||||||
@ -108,7 +113,9 @@ const colorPalettes = {
|
|||||||
"descrip-text": "#444",
|
"descrip-text": "#444",
|
||||||
"drag-text": "#555",
|
"drag-text": "#555",
|
||||||
"error-text": "#F44336",
|
"error-text": "#F44336",
|
||||||
"border-color": "#888"
|
"border-color": "#888",
|
||||||
|
"tr-even-bg-color": "#f9f9f9",
|
||||||
|
"tr-odd-bg-color": "#fff",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -162,7 +169,9 @@ const colorPalettes = {
|
|||||||
"descrip-text": "#586e75", // Base01
|
"descrip-text": "#586e75", // Base01
|
||||||
"drag-text": "#839496", // Base0
|
"drag-text": "#839496", // Base0
|
||||||
"error-text": "#dc322f", // Solarized Red
|
"error-text": "#dc322f", // Solarized Red
|
||||||
"border-color": "#657b83" // Base00
|
"border-color": "#657b83", // Base00
|
||||||
|
"tr-even-bg-color": "#002b36",
|
||||||
|
"tr-odd-bg-color": "#073642",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -191,7 +200,7 @@ app.registerExtension({
|
|||||||
const nodeData = defs[nodeId];
|
const nodeData = defs[nodeId];
|
||||||
|
|
||||||
var inputs = nodeData["input"]["required"];
|
var inputs = nodeData["input"]["required"];
|
||||||
if (nodeData["input"]["optional"] != undefined){
|
if (nodeData["input"]["optional"] !== undefined) {
|
||||||
inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"])
|
inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"])
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -211,7 +220,7 @@ app.registerExtension({
|
|||||||
}
|
}
|
||||||
|
|
||||||
return types;
|
return types;
|
||||||
};
|
}
|
||||||
|
|
||||||
function completeColorPalette(colorPalette) {
|
function completeColorPalette(colorPalette) {
|
||||||
var types = getSlotTypes();
|
var types = getSlotTypes();
|
||||||
@ -225,19 +234,16 @@ app.registerExtension({
|
|||||||
colorPalette.colors.node_slot = sortObjectKeys(colorPalette.colors.node_slot);
|
colorPalette.colors.node_slot = sortObjectKeys(colorPalette.colors.node_slot);
|
||||||
|
|
||||||
return colorPalette;
|
return colorPalette;
|
||||||
};
|
}
|
||||||
|
|
||||||
const getColorPaletteTemplate = async () => {
|
const getColorPaletteTemplate = async () => {
|
||||||
let colorPalette = {
|
let colorPalette = {
|
||||||
"id": "my_color_palette_unique_id",
|
"id": "my_color_palette_unique_id",
|
||||||
"name": "My Color Palette",
|
"name": "My Color Palette",
|
||||||
"colors": {
|
"colors": {
|
||||||
"node_slot": {
|
"node_slot": {},
|
||||||
},
|
"litegraph_base": {},
|
||||||
"litegraph_base": {
|
"comfy_base": {}
|
||||||
},
|
|
||||||
"comfy_base": {
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -266,32 +272,32 @@ app.registerExtension({
|
|||||||
};
|
};
|
||||||
|
|
||||||
const addCustomColorPalette = async (colorPalette) => {
|
const addCustomColorPalette = async (colorPalette) => {
|
||||||
if (typeof(colorPalette) !== "object") {
|
if (typeof (colorPalette) !== "object") {
|
||||||
app.ui.dialog.show("Invalid color palette");
|
alert("Invalid color palette.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!colorPalette.id) {
|
if (!colorPalette.id) {
|
||||||
app.ui.dialog.show("Color palette missing id");
|
alert("Color palette missing id.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!colorPalette.name) {
|
if (!colorPalette.name) {
|
||||||
app.ui.dialog.show("Color palette missing name");
|
alert("Color palette missing name.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!colorPalette.colors) {
|
if (!colorPalette.colors) {
|
||||||
app.ui.dialog.show("Color palette missing colors");
|
alert("Color palette missing colors.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (colorPalette.colors.node_slot && typeof(colorPalette.colors.node_slot) !== "object") {
|
if (colorPalette.colors.node_slot && typeof (colorPalette.colors.node_slot) !== "object") {
|
||||||
app.ui.dialog.show("Invalid color palette colors.node_slot");
|
alert("Invalid color palette colors.node_slot.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let customColorPalettes = getCustomColorPalettes();
|
const customColorPalettes = getCustomColorPalettes();
|
||||||
customColorPalettes[colorPalette.id] = colorPalette;
|
customColorPalettes[colorPalette.id] = colorPalette;
|
||||||
setCustomColorPalettes(customColorPalettes);
|
setCustomColorPalettes(customColorPalettes);
|
||||||
|
|
||||||
@ -301,14 +307,18 @@ app.registerExtension({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
els.select.append($el("option", { textContent: colorPalette.name + " (custom)", value: "custom_" + colorPalette.id, selected: true }));
|
els.select.append($el("option", {
|
||||||
|
textContent: colorPalette.name + " (custom)",
|
||||||
|
value: "custom_" + colorPalette.id,
|
||||||
|
selected: true
|
||||||
|
}));
|
||||||
|
|
||||||
setColorPalette("custom_" + colorPalette.id);
|
setColorPalette("custom_" + colorPalette.id);
|
||||||
await loadColorPalette(colorPalette);
|
await loadColorPalette(colorPalette);
|
||||||
};
|
};
|
||||||
|
|
||||||
const deleteCustomColorPalette = async (colorPaletteId) => {
|
const deleteCustomColorPalette = async (colorPaletteId) => {
|
||||||
let customColorPalettes = getCustomColorPalettes();
|
const customColorPalettes = getCustomColorPalettes();
|
||||||
delete customColorPalettes[colorPaletteId];
|
delete customColorPalettes[colorPaletteId];
|
||||||
setCustomColorPalettes(customColorPalettes);
|
setCustomColorPalettes(customColorPalettes);
|
||||||
|
|
||||||
@ -350,7 +360,7 @@ app.registerExtension({
|
|||||||
if (colorPalette.colors.comfy_base) {
|
if (colorPalette.colors.comfy_base) {
|
||||||
const rootStyle = document.documentElement.style;
|
const rootStyle = document.documentElement.style;
|
||||||
for (const key in colorPalette.colors.comfy_base) {
|
for (const key in colorPalette.colors.comfy_base) {
|
||||||
rootStyle.setProperty('--' + key, colorPalette.colors.comfy_base[key]);
|
rootStyle.setProperty('--' + key, colorPalette.colors.comfy_base[key]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
app.canvas.draw(true, true);
|
app.canvas.draw(true, true);
|
||||||
@ -380,11 +390,10 @@ app.registerExtension({
|
|||||||
const fileInput = $el("input", {
|
const fileInput = $el("input", {
|
||||||
type: "file",
|
type: "file",
|
||||||
accept: ".json",
|
accept: ".json",
|
||||||
style: { display: "none" },
|
style: {display: "none"},
|
||||||
parent: document.body,
|
parent: document.body,
|
||||||
onchange: () => {
|
onchange: () => {
|
||||||
let file = fileInput.files[0];
|
const file = fileInput.files[0];
|
||||||
|
|
||||||
if (file.type === "application/json" || file.name.endsWith(".json")) {
|
if (file.type === "application/json" || file.name.endsWith(".json")) {
|
||||||
const reader = new FileReader();
|
const reader = new FileReader();
|
||||||
reader.onload = async () => {
|
reader.onload = async () => {
|
||||||
@ -399,96 +408,116 @@ app.registerExtension({
|
|||||||
id,
|
id,
|
||||||
name: "Color Palette",
|
name: "Color Palette",
|
||||||
type: (name, setter, value) => {
|
type: (name, setter, value) => {
|
||||||
let options = [];
|
const options = [
|
||||||
|
...Object.values(colorPalettes).map(c=> $el("option", {
|
||||||
|
textContent: c.name,
|
||||||
|
value: c.id,
|
||||||
|
selected: c.id === value
|
||||||
|
})),
|
||||||
|
...Object.values(getCustomColorPalettes()).map(c=>$el("option", {
|
||||||
|
textContent: `${c.name} (custom)`,
|
||||||
|
value: `custom_${c.id}`,
|
||||||
|
selected: `custom_${c.id}` === value
|
||||||
|
})) ,
|
||||||
|
];
|
||||||
|
|
||||||
for (const c in colorPalettes) {
|
els.select = $el("select", {
|
||||||
const colorPalette = colorPalettes[c];
|
style: {
|
||||||
options.push($el("option", { textContent: colorPalette.name, value: colorPalette.id, selected: colorPalette.id === value }));
|
marginBottom: "0.15rem",
|
||||||
}
|
width: "100%",
|
||||||
|
},
|
||||||
|
onchange: (e) => {
|
||||||
|
setter(e.target.value);
|
||||||
|
}
|
||||||
|
}, options)
|
||||||
|
|
||||||
let customColorPalettes = getCustomColorPalettes();
|
return $el("tr", [
|
||||||
for (const c in customColorPalettes) {
|
$el("td", [
|
||||||
const colorPalette = customColorPalettes[c];
|
$el("label", {
|
||||||
options.push($el("option", { textContent: colorPalette.name + " (custom)", value: "custom_" + colorPalette.id, selected: "custom_" + colorPalette.id === value }));
|
for: id.replaceAll(".", "-"),
|
||||||
}
|
textContent: "Color palette",
|
||||||
|
}),
|
||||||
return $el("div", [
|
|
||||||
$el("label", { textContent: name || id }, [
|
|
||||||
els.select = $el("select", {
|
|
||||||
onchange: (e) => {
|
|
||||||
setter(e.target.value);
|
|
||||||
}
|
|
||||||
}, options)
|
|
||||||
]),
|
]),
|
||||||
$el("input", {
|
$el("td", [
|
||||||
type: "button",
|
els.select,
|
||||||
value: "Export",
|
$el("div", {
|
||||||
onclick: async () => {
|
style: {
|
||||||
const colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
|
display: "grid",
|
||||||
const colorPalette = await completeColorPalette(getColorPalette(colorPaletteId));
|
gap: "4px",
|
||||||
const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string
|
gridAutoFlow: "column",
|
||||||
const blob = new Blob([json], { type: "application/json" });
|
},
|
||||||
const url = URL.createObjectURL(blob);
|
}, [
|
||||||
const a = $el("a", {
|
$el("input", {
|
||||||
href: url,
|
type: "button",
|
||||||
download: colorPaletteId + ".json",
|
value: "Export",
|
||||||
style: { display: "none" },
|
onclick: async () => {
|
||||||
parent: document.body,
|
const colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
|
||||||
});
|
const colorPalette = await completeColorPalette(getColorPalette(colorPaletteId));
|
||||||
a.click();
|
const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string
|
||||||
setTimeout(function () {
|
const blob = new Blob([json], {type: "application/json"});
|
||||||
a.remove();
|
const url = URL.createObjectURL(blob);
|
||||||
window.URL.revokeObjectURL(url);
|
const a = $el("a", {
|
||||||
}, 0);
|
href: url,
|
||||||
},
|
download: colorPaletteId + ".json",
|
||||||
}),
|
style: {display: "none"},
|
||||||
$el("input", {
|
parent: document.body,
|
||||||
type: "button",
|
});
|
||||||
value: "Import",
|
a.click();
|
||||||
onclick: () => {
|
setTimeout(function () {
|
||||||
fileInput.click();
|
a.remove();
|
||||||
}
|
window.URL.revokeObjectURL(url);
|
||||||
}),
|
}, 0);
|
||||||
$el("input", {
|
},
|
||||||
type: "button",
|
}),
|
||||||
value: "Template",
|
$el("input", {
|
||||||
onclick: async () => {
|
type: "button",
|
||||||
const colorPalette = await getColorPaletteTemplate();
|
value: "Import",
|
||||||
const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string
|
onclick: () => {
|
||||||
const blob = new Blob([json], { type: "application/json" });
|
fileInput.click();
|
||||||
const url = URL.createObjectURL(blob);
|
}
|
||||||
const a = $el("a", {
|
}),
|
||||||
href: url,
|
$el("input", {
|
||||||
download: "color_palette.json",
|
type: "button",
|
||||||
style: { display: "none" },
|
value: "Template",
|
||||||
parent: document.body,
|
onclick: async () => {
|
||||||
});
|
const colorPalette = await getColorPaletteTemplate();
|
||||||
a.click();
|
const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string
|
||||||
setTimeout(function () {
|
const blob = new Blob([json], {type: "application/json"});
|
||||||
a.remove();
|
const url = URL.createObjectURL(blob);
|
||||||
window.URL.revokeObjectURL(url);
|
const a = $el("a", {
|
||||||
}, 0);
|
href: url,
|
||||||
}
|
download: "color_palette.json",
|
||||||
}),
|
style: {display: "none"},
|
||||||
$el("input", {
|
parent: document.body,
|
||||||
type: "button",
|
});
|
||||||
value: "Delete",
|
a.click();
|
||||||
onclick: async () => {
|
setTimeout(function () {
|
||||||
let colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
|
a.remove();
|
||||||
|
window.URL.revokeObjectURL(url);
|
||||||
|
}, 0);
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
$el("input", {
|
||||||
|
type: "button",
|
||||||
|
value: "Delete",
|
||||||
|
onclick: async () => {
|
||||||
|
let colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
|
||||||
|
|
||||||
if (colorPalettes[colorPaletteId]) {
|
if (colorPalettes[colorPaletteId]) {
|
||||||
app.ui.dialog.show("You cannot delete built-in color palette");
|
alert("You cannot delete a built-in color palette.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (colorPaletteId.startsWith("custom_")) {
|
if (colorPaletteId.startsWith("custom_")) {
|
||||||
colorPaletteId = colorPaletteId.substr(7);
|
colorPaletteId = colorPaletteId.substr(7);
|
||||||
}
|
}
|
||||||
|
|
||||||
await deleteCustomColorPalette(colorPaletteId);
|
await deleteCustomColorPalette(colorPaletteId);
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
]);
|
]),
|
||||||
|
]),
|
||||||
|
])
|
||||||
},
|
},
|
||||||
defaultValue: defaultColorPaletteId,
|
defaultValue: defaultColorPaletteId,
|
||||||
async onChange(value) {
|
async onChange(value) {
|
||||||
@ -496,15 +525,25 @@ app.registerExtension({
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (colorPalettes[value]) {
|
let palette = colorPalettes[value];
|
||||||
await loadColorPalette(colorPalettes[value]);
|
if (palette) {
|
||||||
|
await loadColorPalette(palette);
|
||||||
} else if (value.startsWith("custom_")) {
|
} else if (value.startsWith("custom_")) {
|
||||||
value = value.substr(7);
|
value = value.substr(7);
|
||||||
let customColorPalettes = getCustomColorPalettes();
|
let customColorPalettes = getCustomColorPalettes();
|
||||||
if (customColorPalettes[value]) {
|
if (customColorPalettes[value]) {
|
||||||
|
palette = customColorPalettes[value];
|
||||||
await loadColorPalette(customColorPalettes[value]);
|
await loadColorPalette(customColorPalettes[value]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let {BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR} = palette.colors.litegraph_base;
|
||||||
|
if (BACKGROUND_IMAGE === undefined || CLEAR_BACKGROUND_COLOR === undefined) {
|
||||||
|
const base = colorPalettes["dark"].colors.litegraph_base;
|
||||||
|
BACKGROUND_IMAGE = base.BACKGROUND_IMAGE;
|
||||||
|
CLEAR_BACKGROUND_COLOR = base.CLEAR_BACKGROUND_COLOR;
|
||||||
|
}
|
||||||
|
app.canvas.updateBackground(BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
|||||||
@ -1,38 +1,31 @@
|
|||||||
import { app } from "/scripts/app.js";
|
import {app} from "/scripts/app.js";
|
||||||
|
|
||||||
const id = "Comfy.Keybinds";
|
|
||||||
app.registerExtension({
|
app.registerExtension({
|
||||||
name: id,
|
name: "Comfy.Keybinds",
|
||||||
init() {
|
init() {
|
||||||
const keybindListener = function(event) {
|
const keybindListener = function (event) {
|
||||||
const modifierPressed = event.ctrlKey || event.metaKey;
|
const modifierPressed = event.ctrlKey || event.metaKey;
|
||||||
|
|
||||||
// Queue prompt using ctrl or command + enter
|
// Queue prompt using ctrl or command + enter
|
||||||
if (modifierPressed && (event.key === "Enter" || event.keyCode === 13 || event.keyCode === 10)) {
|
if (modifierPressed && event.key === "Enter") {
|
||||||
app.queuePrompt(event.shiftKey ? -1 : 0);
|
app.queuePrompt(event.shiftKey ? -1 : 0).then();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const target = event.composedPath()[0];
|
const target = event.composedPath()[0];
|
||||||
|
if (["INPUT", "TEXTAREA"].includes(target.tagName)) {
|
||||||
if (target.tagName === "INPUT" || target.tagName === "TEXTAREA") {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const modifierKeyIdMap = {
|
const modifierKeyIdMap = {
|
||||||
"s": "#comfy-save-button",
|
s: "#comfy-save-button",
|
||||||
83: "#comfy-save-button",
|
o: "#comfy-file-input",
|
||||||
"o": "#comfy-file-input",
|
Backspace: "#comfy-clear-button",
|
||||||
79: "#comfy-file-input",
|
Delete: "#comfy-clear-button",
|
||||||
"Backspace": "#comfy-clear-button",
|
d: "#comfy-load-default-button",
|
||||||
8: "#comfy-clear-button",
|
|
||||||
"Delete": "#comfy-clear-button",
|
|
||||||
46: "#comfy-clear-button",
|
|
||||||
"d": "#comfy-load-default-button",
|
|
||||||
68: "#comfy-load-default-button",
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const modifierKeybindId = modifierKeyIdMap[event.key] || modifierKeyIdMap[event.keyCode];
|
const modifierKeybindId = modifierKeyIdMap[event.key];
|
||||||
if (modifierPressed && modifierKeybindId) {
|
if (modifierPressed && modifierKeybindId) {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
@ -47,24 +40,25 @@ app.registerExtension({
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close out of modals using escape
|
// Close out of modals using escape
|
||||||
if (event.key === "Escape" || event.keyCode === 27) {
|
if (event.key === "Escape") {
|
||||||
const modals = document.querySelectorAll(".comfy-modal");
|
const modals = document.querySelectorAll(".comfy-modal");
|
||||||
const modal = Array.from(modals).find(modal => window.getComputedStyle(modal).getPropertyValue("display") !== "none");
|
const modal = Array.from(modals).find(modal => window.getComputedStyle(modal).getPropertyValue("display") !== "none");
|
||||||
if (modal) {
|
if (modal) {
|
||||||
modal.style.display = "none";
|
modal.style.display = "none";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[...document.querySelectorAll("dialog")].forEach(d => {
|
||||||
|
d.close();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const keyIdMap = {
|
const keyIdMap = {
|
||||||
"q": "#comfy-view-queue-button",
|
q: "#comfy-view-queue-button",
|
||||||
81: "#comfy-view-queue-button",
|
h: "#comfy-view-history-button",
|
||||||
"h": "#comfy-view-history-button",
|
r: "#comfy-refresh-button",
|
||||||
72: "#comfy-view-history-button",
|
|
||||||
"r": "#comfy-refresh-button",
|
|
||||||
82: "#comfy-refresh-button",
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const buttonId = keyIdMap[event.key] || keyIdMap[event.keyCode];
|
const buttonId = keyIdMap[event.key];
|
||||||
if (buttonId) {
|
if (buttonId) {
|
||||||
const button = document.querySelector(buttonId);
|
const button = document.querySelector(buttonId);
|
||||||
button.click();
|
button.click();
|
||||||
|
|||||||
@ -24,9 +24,13 @@ app.registerExtension({
|
|||||||
// Ignore wildcard nodes as these will be updated to real types
|
// Ignore wildcard nodes as these will be updated to real types
|
||||||
const types = new Set(this.outputs[0].links.map((l) => app.graph.links[l].type).filter((t) => t !== "*"));
|
const types = new Set(this.outputs[0].links.map((l) => app.graph.links[l].type).filter((t) => t !== "*"));
|
||||||
if (types.size > 1) {
|
if (types.size > 1) {
|
||||||
|
const linksToDisconnect = [];
|
||||||
for (let i = 0; i < this.outputs[0].links.length - 1; i++) {
|
for (let i = 0; i < this.outputs[0].links.length - 1; i++) {
|
||||||
const linkId = this.outputs[0].links[i];
|
const linkId = this.outputs[0].links[i];
|
||||||
const link = app.graph.links[linkId];
|
const link = app.graph.links[linkId];
|
||||||
|
linksToDisconnect.push(link);
|
||||||
|
}
|
||||||
|
for (const link of linksToDisconnect) {
|
||||||
const node = app.graph.getNodeById(link.target_id);
|
const node = app.graph.getNodeById(link.target_id);
|
||||||
node.disconnectInput(link.target_slot);
|
node.disconnectInput(link.target_slot);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,7 +10,7 @@ app.registerExtension({
|
|||||||
LiteGraph.middle_click_slot_add_default_node = true;
|
LiteGraph.middle_click_slot_add_default_node = true;
|
||||||
this.suggestionsNumber = app.ui.settings.addSetting({
|
this.suggestionsNumber = app.ui.settings.addSetting({
|
||||||
id: "Comfy.NodeSuggestions.number",
|
id: "Comfy.NodeSuggestions.number",
|
||||||
name: "number of nodes suggestions",
|
name: "Number of nodes suggestions",
|
||||||
type: "slider",
|
type: "slider",
|
||||||
attrs: {
|
attrs: {
|
||||||
min: 1,
|
min: 1,
|
||||||
|
|||||||
@ -59,6 +59,10 @@ function convertToInput(node, widget, config) {
|
|||||||
widget: { name: widget.name, config },
|
widget: { name: widget.name, config },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
for (const widget of node.widgets) {
|
||||||
|
widget.last_y += LiteGraph.NODE_SLOT_HEIGHT;
|
||||||
|
}
|
||||||
|
|
||||||
// Restore original size but grow if needed
|
// Restore original size but grow if needed
|
||||||
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
|
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
|
||||||
}
|
}
|
||||||
@ -68,6 +72,10 @@ function convertToWidget(node, widget) {
|
|||||||
const sz = node.size;
|
const sz = node.size;
|
||||||
node.removeInput(node.inputs.findIndex((i) => i.widget?.name === widget.name));
|
node.removeInput(node.inputs.findIndex((i) => i.widget?.name === widget.name));
|
||||||
|
|
||||||
|
for (const widget of node.widgets) {
|
||||||
|
widget.last_y -= LiteGraph.NODE_SLOT_HEIGHT;
|
||||||
|
}
|
||||||
|
|
||||||
// Restore original size but grow if needed
|
// Restore original size but grow if needed
|
||||||
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
|
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,6 +7,7 @@
|
|||||||
<link rel="stylesheet" type="text/css" href="lib/litegraph.css" />
|
<link rel="stylesheet" type="text/css" href="lib/litegraph.css" />
|
||||||
<link rel="stylesheet" type="text/css" href="style.css" />
|
<link rel="stylesheet" type="text/css" href="style.css" />
|
||||||
<script type="text/javascript" src="lib/litegraph.core.js"></script>
|
<script type="text/javascript" src="lib/litegraph.core.js"></script>
|
||||||
|
<script type="text/javascript" src="lib/litegraph.extensions.js" defer></script>
|
||||||
<script type="module">
|
<script type="module">
|
||||||
import { app } from "/scripts/app.js";
|
import { app } from "/scripts/app.js";
|
||||||
await app.setup();
|
await app.setup();
|
||||||
|
|||||||
21
web/lib/litegraph.extensions.js
Normal file
21
web/lib/litegraph.extensions.js
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
/**
|
||||||
|
* Changes the background color of the canvas.
|
||||||
|
*
|
||||||
|
* @method updateBackground
|
||||||
|
* @param {image} String
|
||||||
|
* @param {clearBackgroundColor} String
|
||||||
|
* @
|
||||||
|
*/
|
||||||
|
LGraphCanvas.prototype.updateBackground = function (image, clearBackgroundColor) {
|
||||||
|
this._bg_img = new Image();
|
||||||
|
this._bg_img.name = image;
|
||||||
|
this._bg_img.src = image;
|
||||||
|
this._bg_img.onload = () => {
|
||||||
|
this.draw(true, true);
|
||||||
|
};
|
||||||
|
this.background_image = image;
|
||||||
|
|
||||||
|
this.clear_background = true;
|
||||||
|
this.clear_background_color = clearBackgroundColor;
|
||||||
|
this._pattern = null
|
||||||
|
}
|
||||||
@ -120,6 +120,9 @@ class ComfyApi extends EventTarget {
|
|||||||
case "execution_error":
|
case "execution_error":
|
||||||
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
|
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
|
||||||
break;
|
break;
|
||||||
|
case "execution_cached":
|
||||||
|
this.dispatchEvent(new CustomEvent("execution_cached", { detail: msg.data }));
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
if (this.#registered.has(msg.type)) {
|
if (this.#registered.has(msg.type)) {
|
||||||
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
|
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
|
||||||
|
|||||||
@ -1,19 +1,26 @@
|
|||||||
import { api } from "./api.js";
|
import {api} from "./api.js";
|
||||||
|
|
||||||
export function $el(tag, propsOrChildren, children) {
|
export function $el(tag, propsOrChildren, children) {
|
||||||
const split = tag.split(".");
|
const split = tag.split(".");
|
||||||
const element = document.createElement(split.shift());
|
const element = document.createElement(split.shift());
|
||||||
element.classList.add(...split);
|
if (split.length > 0) {
|
||||||
|
element.classList.add(...split);
|
||||||
|
}
|
||||||
|
|
||||||
if (propsOrChildren) {
|
if (propsOrChildren) {
|
||||||
if (Array.isArray(propsOrChildren)) {
|
if (Array.isArray(propsOrChildren)) {
|
||||||
element.append(...propsOrChildren);
|
element.append(...propsOrChildren);
|
||||||
} else {
|
} else {
|
||||||
const { parent, $: cb, dataset, style } = propsOrChildren;
|
const {parent, $: cb, dataset, style} = propsOrChildren;
|
||||||
delete propsOrChildren.parent;
|
delete propsOrChildren.parent;
|
||||||
delete propsOrChildren.$;
|
delete propsOrChildren.$;
|
||||||
delete propsOrChildren.dataset;
|
delete propsOrChildren.dataset;
|
||||||
delete propsOrChildren.style;
|
delete propsOrChildren.style;
|
||||||
|
|
||||||
|
if (Object.hasOwn(propsOrChildren, "for")) {
|
||||||
|
element.setAttribute("for", propsOrChildren.for)
|
||||||
|
}
|
||||||
|
|
||||||
if (style) {
|
if (style) {
|
||||||
Object.assign(element.style, style);
|
Object.assign(element.style, style);
|
||||||
}
|
}
|
||||||
@ -119,6 +126,7 @@ function dragElement(dragEl, settings) {
|
|||||||
savePos = value;
|
savePos = value;
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
function dragMouseDown(e) {
|
function dragMouseDown(e) {
|
||||||
e = e || window.event;
|
e = e || window.event;
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
@ -161,8 +169,8 @@ function dragElement(dragEl, settings) {
|
|||||||
|
|
||||||
export class ComfyDialog {
|
export class ComfyDialog {
|
||||||
constructor() {
|
constructor() {
|
||||||
this.element = $el("div.comfy-modal", { parent: document.body }, [
|
this.element = $el("div.comfy-modal", {parent: document.body}, [
|
||||||
$el("div.comfy-modal-content", [$el("p", { $: (p) => (this.textElement = p) }), ...this.createButtons()]),
|
$el("div.comfy-modal-content", [$el("p", {$: (p) => (this.textElement = p)}), ...this.createButtons()]),
|
||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -193,7 +201,25 @@ export class ComfyDialog {
|
|||||||
class ComfySettingsDialog extends ComfyDialog {
|
class ComfySettingsDialog extends ComfyDialog {
|
||||||
constructor() {
|
constructor() {
|
||||||
super();
|
super();
|
||||||
this.element.classList.add("comfy-settings");
|
this.element = $el("dialog", {
|
||||||
|
id: "comfy-settings-dialog",
|
||||||
|
parent: document.body,
|
||||||
|
}, [
|
||||||
|
$el("table.comfy-modal-content.comfy-table", [
|
||||||
|
$el("caption", {textContent: "Settings"}),
|
||||||
|
$el("tbody", {$: (tbody) => (this.textElement = tbody)}),
|
||||||
|
$el("button", {
|
||||||
|
type: "button",
|
||||||
|
textContent: "Close",
|
||||||
|
style: {
|
||||||
|
cursor: "pointer",
|
||||||
|
},
|
||||||
|
onclick: () => {
|
||||||
|
this.element.close();
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
]);
|
||||||
this.settings = [];
|
this.settings = [];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -208,15 +234,16 @@ class ComfySettingsDialog extends ComfyDialog {
|
|||||||
localStorage[settingId] = JSON.stringify(value);
|
localStorage[settingId] = JSON.stringify(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", }) {
|
addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "",}) {
|
||||||
if (!id) {
|
if (!id) {
|
||||||
throw new Error("Settings must have an ID");
|
throw new Error("Settings must have an ID");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.settings.find((s) => s.id === id)) {
|
if (this.settings.find((s) => s.id === id)) {
|
||||||
throw new Error("Setting IDs must be unique");
|
throw new Error(`Setting ${id} of type ${type} must have a unique ID.`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const settingId = "Comfy.Settings." + id;
|
const settingId = `Comfy.Settings.${id}`;
|
||||||
const v = localStorage[settingId];
|
const v = localStorage[settingId];
|
||||||
let value = v == null ? defaultValue : JSON.parse(v);
|
let value = v == null ? defaultValue : JSON.parse(v);
|
||||||
|
|
||||||
@ -234,34 +261,50 @@ class ComfySettingsDialog extends ComfyDialog {
|
|||||||
localStorage[settingId] = JSON.stringify(v);
|
localStorage[settingId] = JSON.stringify(v);
|
||||||
value = v;
|
value = v;
|
||||||
};
|
};
|
||||||
|
value = this.getSettingValue(id, defaultValue);
|
||||||
|
|
||||||
let element;
|
let element;
|
||||||
value = this.getSettingValue(id, defaultValue);
|
const htmlID = id.replaceAll(".", "-");
|
||||||
|
|
||||||
|
const labelCell = $el("td", [
|
||||||
|
$el("label", {
|
||||||
|
for: htmlID,
|
||||||
|
classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""],
|
||||||
|
textContent: name,
|
||||||
|
})
|
||||||
|
]);
|
||||||
|
|
||||||
if (typeof type === "function") {
|
if (typeof type === "function") {
|
||||||
element = type(name, setter, value, attrs);
|
element = type(name, setter, value, attrs);
|
||||||
} else {
|
} else {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case "boolean":
|
case "boolean":
|
||||||
element = $el("div", [
|
element = $el("tr", [
|
||||||
$el("label", { textContent: name || id }, [
|
labelCell,
|
||||||
|
$el("td", [
|
||||||
$el("input", {
|
$el("input", {
|
||||||
|
id: htmlID,
|
||||||
type: "checkbox",
|
type: "checkbox",
|
||||||
checked: !!value,
|
checked: value,
|
||||||
oninput: (e) => {
|
onchange: (event) => {
|
||||||
setter(e.target.checked);
|
const isChecked = event.target.checked;
|
||||||
|
if (onChange !== undefined) {
|
||||||
|
onChange(isChecked)
|
||||||
|
}
|
||||||
|
this.setSettingValue(id, isChecked);
|
||||||
},
|
},
|
||||||
...attrs
|
|
||||||
}),
|
}),
|
||||||
]),
|
]),
|
||||||
]);
|
])
|
||||||
break;
|
break;
|
||||||
case "number":
|
case "number":
|
||||||
element = $el("div", [
|
element = $el("tr", [
|
||||||
$el("label", { textContent: name || id }, [
|
labelCell,
|
||||||
|
$el("td", [
|
||||||
$el("input", {
|
$el("input", {
|
||||||
type,
|
type,
|
||||||
value,
|
value,
|
||||||
|
id: htmlID,
|
||||||
oninput: (e) => {
|
oninput: (e) => {
|
||||||
setter(e.target.value);
|
setter(e.target.value);
|
||||||
},
|
},
|
||||||
@ -271,46 +314,62 @@ class ComfySettingsDialog extends ComfyDialog {
|
|||||||
]);
|
]);
|
||||||
break;
|
break;
|
||||||
case "slider":
|
case "slider":
|
||||||
element = $el("div", [
|
element = $el("tr", [
|
||||||
$el("label", { textContent: name }, [
|
labelCell,
|
||||||
$el("input", {
|
$el("td", [
|
||||||
type: "range",
|
$el("div", {
|
||||||
value,
|
style: {
|
||||||
oninput: (e) => {
|
display: "grid",
|
||||||
setter(e.target.value);
|
gridAutoFlow: "column",
|
||||||
e.target.nextElementSibling.value = e.target.value;
|
|
||||||
},
|
},
|
||||||
...attrs
|
}, [
|
||||||
}),
|
$el("input", {
|
||||||
$el("input", {
|
...attrs,
|
||||||
type: "number",
|
value,
|
||||||
value,
|
type: "range",
|
||||||
oninput: (e) => {
|
oninput: (e) => {
|
||||||
setter(e.target.value);
|
setter(e.target.value);
|
||||||
e.target.previousElementSibling.value = e.target.value;
|
e.target.nextElementSibling.value = e.target.value;
|
||||||
},
|
},
|
||||||
...attrs
|
}),
|
||||||
}),
|
$el("input", {
|
||||||
|
...attrs,
|
||||||
|
value,
|
||||||
|
id: htmlID,
|
||||||
|
type: "number",
|
||||||
|
style: {maxWidth: "4rem"},
|
||||||
|
oninput: (e) => {
|
||||||
|
setter(e.target.value);
|
||||||
|
e.target.previousElementSibling.value = e.target.value;
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
]),
|
||||||
]),
|
]),
|
||||||
]);
|
]);
|
||||||
break;
|
break;
|
||||||
|
case "text":
|
||||||
default:
|
default:
|
||||||
console.warn("Unsupported setting type, defaulting to text");
|
if (type !== "text") {
|
||||||
element = $el("div", [
|
console.warn(`Unsupported setting type '${type}, defaulting to text`);
|
||||||
$el("label", { textContent: name || id }, [
|
}
|
||||||
|
|
||||||
|
element = $el("tr", [
|
||||||
|
labelCell,
|
||||||
|
$el("td", [
|
||||||
$el("input", {
|
$el("input", {
|
||||||
value,
|
value,
|
||||||
|
id: htmlID,
|
||||||
oninput: (e) => {
|
oninput: (e) => {
|
||||||
setter(e.target.value);
|
setter(e.target.value);
|
||||||
},
|
},
|
||||||
...attrs
|
...attrs,
|
||||||
}),
|
}),
|
||||||
]),
|
]),
|
||||||
]);
|
]);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(tooltip) {
|
if (tooltip) {
|
||||||
element.title = tooltip;
|
element.title = tooltip;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -330,13 +389,16 @@ class ComfySettingsDialog extends ComfyDialog {
|
|||||||
}
|
}
|
||||||
|
|
||||||
show() {
|
show() {
|
||||||
super.show();
|
this.textElement.replaceChildren(
|
||||||
Object.assign(this.textElement.style, {
|
$el("tr", {
|
||||||
display: "flex",
|
style: {display: "none"},
|
||||||
flexDirection: "column",
|
}, [
|
||||||
gap: "10px"
|
$el("th"),
|
||||||
});
|
$el("th", {style: {width: "33%"}})
|
||||||
this.textElement.replaceChildren(...this.settings.map((s) => s.render()));
|
]),
|
||||||
|
...this.settings.map((s) => s.render()),
|
||||||
|
)
|
||||||
|
this.element.showModal();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -369,7 +431,7 @@ class ComfyList {
|
|||||||
name: "Delete",
|
name: "Delete",
|
||||||
cb: () => api.deleteItem(this.#type, item.prompt[1]),
|
cb: () => api.deleteItem(this.#type, item.prompt[1]),
|
||||||
};
|
};
|
||||||
return $el("div", { textContent: item.prompt[0] + ": " }, [
|
return $el("div", {textContent: item.prompt[0] + ": "}, [
|
||||||
$el("button", {
|
$el("button", {
|
||||||
textContent: "Load",
|
textContent: "Load",
|
||||||
onclick: () => {
|
onclick: () => {
|
||||||
@ -398,7 +460,7 @@ class ComfyList {
|
|||||||
await this.load();
|
await this.load();
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
$el("button", { textContent: "Refresh", onclick: () => this.load() }),
|
$el("button", {textContent: "Refresh", onclick: () => this.load()}),
|
||||||
])
|
])
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -475,8 +537,8 @@ export class ComfyUI {
|
|||||||
*/
|
*/
|
||||||
const previewImage = this.settings.addSetting({
|
const previewImage = this.settings.addSetting({
|
||||||
id: "Comfy.PreviewFormat",
|
id: "Comfy.PreviewFormat",
|
||||||
name: "When displaying a preview in the image widget, convert it to a lightweight image. (webp, jpeg, webp;50, ...)",
|
name: "When displaying a preview in the image widget, convert it to a lightweight image, e.g. webp, jpeg, webp;50, etc.",
|
||||||
type: "string",
|
type: "text",
|
||||||
defaultValue: "",
|
defaultValue: "",
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -484,18 +546,25 @@ export class ComfyUI {
|
|||||||
id: "comfy-file-input",
|
id: "comfy-file-input",
|
||||||
type: "file",
|
type: "file",
|
||||||
accept: ".json,image/png,.latent",
|
accept: ".json,image/png,.latent",
|
||||||
style: { display: "none" },
|
style: {display: "none"},
|
||||||
parent: document.body,
|
parent: document.body,
|
||||||
onchange: () => {
|
onchange: () => {
|
||||||
app.handleFile(fileInput.files[0]);
|
app.handleFile(fileInput.files[0]);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [
|
this.menuContainer = $el("div.comfy-menu", {parent: document.body}, [
|
||||||
$el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [
|
$el("div.drag-handle", {
|
||||||
|
style: {
|
||||||
|
overflow: "hidden",
|
||||||
|
position: "relative",
|
||||||
|
width: "100%",
|
||||||
|
cursor: "default"
|
||||||
|
}
|
||||||
|
}, [
|
||||||
$el("span.drag-handle"),
|
$el("span.drag-handle"),
|
||||||
$el("span", { $: (q) => (this.queueSize = q) }),
|
$el("span", {$: (q) => (this.queueSize = q)}),
|
||||||
$el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }),
|
$el("button.comfy-settings-btn", {textContent: "⚙️", onclick: () => this.settings.show()}),
|
||||||
]),
|
]),
|
||||||
$el("button.comfy-queue-btn", {
|
$el("button.comfy-queue-btn", {
|
||||||
id: "queue-button",
|
id: "queue-button",
|
||||||
@ -503,7 +572,7 @@ export class ComfyUI {
|
|||||||
onclick: () => app.queuePrompt(0, this.batchCount),
|
onclick: () => app.queuePrompt(0, this.batchCount),
|
||||||
}),
|
}),
|
||||||
$el("div", {}, [
|
$el("div", {}, [
|
||||||
$el("label", { innerHTML: "Extra options" }, [
|
$el("label", {innerHTML: "Extra options"}, [
|
||||||
$el("input", {
|
$el("input", {
|
||||||
type: "checkbox",
|
type: "checkbox",
|
||||||
onchange: (i) => {
|
onchange: (i) => {
|
||||||
@ -514,14 +583,14 @@ export class ComfyUI {
|
|||||||
}),
|
}),
|
||||||
]),
|
]),
|
||||||
]),
|
]),
|
||||||
$el("div", { id: "extraOptions", style: { width: "100%", display: "none" } }, [
|
$el("div", {id: "extraOptions", style: {width: "100%", display: "none"}}, [
|
||||||
$el("label", { innerHTML: "Batch count" }, [
|
$el("label", {innerHTML: "Batch count"}, [
|
||||||
$el("input", {
|
$el("input", {
|
||||||
id: "batchCountInputNumber",
|
id: "batchCountInputNumber",
|
||||||
type: "number",
|
type: "number",
|
||||||
value: this.batchCount,
|
value: this.batchCount,
|
||||||
min: "1",
|
min: "1",
|
||||||
style: { width: "35%", "margin-left": "0.4em" },
|
style: {width: "35%", "margin-left": "0.4em"},
|
||||||
oninput: (i) => {
|
oninput: (i) => {
|
||||||
this.batchCount = i.target.value;
|
this.batchCount = i.target.value;
|
||||||
document.getElementById("batchCountInputRange").value = this.batchCount;
|
document.getElementById("batchCountInputRange").value = this.batchCount;
|
||||||
@ -547,7 +616,11 @@ export class ComfyUI {
|
|||||||
]),
|
]),
|
||||||
]),
|
]),
|
||||||
$el("div.comfy-menu-btns", [
|
$el("div.comfy-menu-btns", [
|
||||||
$el("button", { id: "queue-front-button", textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }),
|
$el("button", {
|
||||||
|
id: "queue-front-button",
|
||||||
|
textContent: "Queue Front",
|
||||||
|
onclick: () => app.queuePrompt(-1, this.batchCount)
|
||||||
|
}),
|
||||||
$el("button", {
|
$el("button", {
|
||||||
$: (b) => (this.queue.button = b),
|
$: (b) => (this.queue.button = b),
|
||||||
id: "comfy-view-queue-button",
|
id: "comfy-view-queue-button",
|
||||||
@ -582,12 +655,12 @@ export class ComfyUI {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string
|
const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string
|
||||||
const blob = new Blob([json], { type: "application/json" });
|
const blob = new Blob([json], {type: "application/json"});
|
||||||
const url = URL.createObjectURL(blob);
|
const url = URL.createObjectURL(blob);
|
||||||
const a = $el("a", {
|
const a = $el("a", {
|
||||||
href: url,
|
href: url,
|
||||||
download: filename,
|
download: filename,
|
||||||
style: { display: "none" },
|
style: {display: "none"},
|
||||||
parent: document.body,
|
parent: document.body,
|
||||||
});
|
});
|
||||||
a.click();
|
a.click();
|
||||||
@ -597,25 +670,33 @@ export class ComfyUI {
|
|||||||
}, 0);
|
}, 0);
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
$el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }),
|
$el("button", {id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click()}),
|
||||||
$el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
|
$el("button", {
|
||||||
$el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }),
|
id: "comfy-refresh-button",
|
||||||
$el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => {
|
textContent: "Refresh",
|
||||||
if (!confirmClear.value || confirm("Clear workflow?")) {
|
onclick: () => app.refreshComboInNodes()
|
||||||
app.clean();
|
}),
|
||||||
app.graph.clear();
|
$el("button", {id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace()}),
|
||||||
|
$el("button", {
|
||||||
|
id: "comfy-clear-button", textContent: "Clear", onclick: () => {
|
||||||
|
if (!confirmClear.value || confirm("Clear workflow?")) {
|
||||||
|
app.clean();
|
||||||
|
app.graph.clear();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}}),
|
}),
|
||||||
$el("button", { id: "comfy-load-default-button", textContent: "Load Default", onclick: () => {
|
$el("button", {
|
||||||
if (!confirmClear.value || confirm("Load default workflow?")) {
|
id: "comfy-load-default-button", textContent: "Load Default", onclick: () => {
|
||||||
app.loadGraphData()
|
if (!confirmClear.value || confirm("Load default workflow?")) {
|
||||||
|
app.loadGraphData()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}}),
|
}),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
dragElement(this.menuContainer, this.settings);
|
dragElement(this.menuContainer, this.settings);
|
||||||
|
|
||||||
this.setStatus({ exec_info: { queue_remaining: "X" } });
|
this.setStatus({exec_info: {queue_remaining: "X"}});
|
||||||
}
|
}
|
||||||
|
|
||||||
setStatus(status) {
|
setStatus(status) {
|
||||||
|
|||||||
@ -8,6 +8,8 @@
|
|||||||
--drag-text: #ccc;
|
--drag-text: #ccc;
|
||||||
--error-text: #ff4444;
|
--error-text: #ff4444;
|
||||||
--border-color: #4e4e4e;
|
--border-color: #4e4e4e;
|
||||||
|
--tr-even-bg-color: #222;
|
||||||
|
--tr-odd-bg-color: #353535;
|
||||||
}
|
}
|
||||||
|
|
||||||
@media (prefers-color-scheme: dark) {
|
@media (prefers-color-scheme: dark) {
|
||||||
@ -220,7 +222,7 @@ button.comfy-queue-btn {
|
|||||||
margin: 6px 0 !important;
|
margin: 6px 0 !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.comfy-modal.comfy-settings,
|
.comfy-modal.comfy-settings,
|
||||||
.comfy-modal.comfy-manage-templates {
|
.comfy-modal.comfy-manage-templates {
|
||||||
text-align: center;
|
text-align: center;
|
||||||
font-family: sans-serif;
|
font-family: sans-serif;
|
||||||
@ -246,6 +248,11 @@ button.comfy-queue-btn {
|
|||||||
font-size: inherit;
|
font-size: inherit;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.comfy-tooltip-indicator {
|
||||||
|
text-decoration: underline;
|
||||||
|
text-decoration-style: dashed;
|
||||||
|
}
|
||||||
|
|
||||||
@media only screen and (max-height: 850px) {
|
@media only screen and (max-height: 850px) {
|
||||||
.comfy-menu {
|
.comfy-menu {
|
||||||
top: 0 !important;
|
top: 0 !important;
|
||||||
@ -254,8 +261,9 @@ button.comfy-queue-btn {
|
|||||||
right: 0 !important;
|
right: 0 !important;
|
||||||
border-radius: 0;
|
border-radius: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.comfy-menu span.drag-handle {
|
.comfy-menu span.drag-handle {
|
||||||
visibility:hidden
|
visibility: hidden
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -287,11 +295,75 @@ button.comfy-queue-btn {
|
|||||||
border-radius: 12px 0 0 12px;
|
border-radius: 12px 0 0 12px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Dialogs */
|
||||||
|
|
||||||
|
dialog {
|
||||||
|
box-shadow: 0 0 20px #888888;
|
||||||
|
}
|
||||||
|
|
||||||
|
dialog::backdrop {
|
||||||
|
background: rgba(0, 0, 0, 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#comfy-settings-dialog {
|
||||||
|
padding: 0;
|
||||||
|
width: 41rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
#comfy-settings-dialog tr > td:first-child {
|
||||||
|
text-align: right;
|
||||||
|
}
|
||||||
|
|
||||||
|
#comfy-settings-dialog button {
|
||||||
|
background-color: var(--bg-color);
|
||||||
|
border: 1px var(--border-color) solid;
|
||||||
|
border-radius: 0;
|
||||||
|
color: var(--input-text);
|
||||||
|
font-size: 1rem;
|
||||||
|
padding: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
#comfy-settings-dialog button:hover {
|
||||||
|
background-color: var(--tr-odd-bg-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* General CSS for tables */
|
||||||
|
|
||||||
|
.comfy-table {
|
||||||
|
border-collapse: collapse;
|
||||||
|
color: var(--input-text);
|
||||||
|
font-family: Arial, sans-serif;
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfy-table caption {
|
||||||
|
background-color: var(--bg-color);
|
||||||
|
color: var(--input-text);
|
||||||
|
font-size: 1rem;
|
||||||
|
font-weight: bold;
|
||||||
|
padding: 8px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfy-table tr:nth-child(even) {
|
||||||
|
background-color: var(--tr-even-bg-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfy-table tr:nth-child(odd) {
|
||||||
|
background-color: var(--tr-odd-bg-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.comfy-table td,
|
||||||
|
.comfy-table th {
|
||||||
|
border: 1px solid var(--border-color);
|
||||||
|
padding: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
/* Context menu */
|
/* Context menu */
|
||||||
|
|
||||||
.litegraph .dialog {
|
.litegraph .dialog {
|
||||||
z-index: 1;
|
z-index: 1;
|
||||||
font-family: Arial, sans-serif;
|
font-family: Arial, sans-serif;
|
||||||
}
|
}
|
||||||
|
|
||||||
.litegraph .litemenu-entry.has_submenu {
|
.litegraph .litemenu-entry.has_submenu {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user