mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-15 14:17:40 +08:00
merging master into filesubflows and fixing merge conflicts
This commit is contained in:
commit
32f7fd663a
@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x, SD2.x and SDXL
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
|
||||
@ -30,6 +30,8 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
|
||||
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
||||
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Starts up very fast.
|
||||
- Works fully offline: will never download anything.
|
||||
|
||||
@ -27,7 +27,6 @@ class ControlNet(nn.Module):
|
||||
model_channels,
|
||||
hint_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
@ -52,8 +51,10 @@ class ControlNet(nn.Module):
|
||||
use_linear_in_transformer=False,
|
||||
adm_in_channels=None,
|
||||
transformer_depth_middle=None,
|
||||
transformer_depth_output=None,
|
||||
device=None,
|
||||
operations=comfy.ops,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
||||
@ -79,10 +80,7 @@ class ControlNet(nn.Module):
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = len(channel_mult) * [transformer_depth]
|
||||
if transformer_depth_middle is None:
|
||||
transformer_depth_middle = transformer_depth[-1]
|
||||
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
@ -90,18 +88,16 @@ class ControlNet(nn.Module):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.")
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
transformer_depth = transformer_depth[:]
|
||||
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
@ -180,11 +176,14 @@ class ControlNet(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
operations=operations
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
num_transformers = transformer_depth.pop(0)
|
||||
if num_transformers > 0:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
@ -201,9 +200,9 @@ class ControlNet(nn.Module):
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, operations=operations
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
@ -223,11 +222,13 @@ class ControlNet(nn.Module):
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -245,7 +246,7 @@ class ControlNet(nn.Module):
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
mid_block = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
@ -253,12 +254,15 @@ class ControlNet(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
SpatialTransformer( # always uses a self-attn
|
||||
)]
|
||||
if transformer_depth_middle >= 0:
|
||||
mid_block += [SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, operations=operations
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
@ -267,9 +271,11 @@ class ControlNet(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
)
|
||||
)]
|
||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
|
||||
self._feature_size += ch
|
||||
|
||||
|
||||
@ -36,6 +36,8 @@ parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
||||
|
||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
||||
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
|
||||
@ -60,6 +62,13 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in
|
||||
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
|
||||
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
|
||||
|
||||
fpte_group = parser.add_mutually_exclusive_group()
|
||||
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
|
||||
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
|
||||
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
|
||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
||||
|
||||
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
|
||||
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
|
||||
from .utils import load_torch_file, transformers_convert
|
||||
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
|
||||
from .utils import load_torch_file, transformers_convert, common_upscale
|
||||
import os
|
||||
import torch
|
||||
import contextlib
|
||||
@ -7,6 +7,18 @@ import contextlib
|
||||
import comfy.ops
|
||||
import comfy.model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
|
||||
def clip_preprocess(image, size=224):
|
||||
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
|
||||
scale = (size / min(image.shape[1], image.shape[2]))
|
||||
image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
@ -23,25 +35,12 @@ class ClipVisionModel():
|
||||
self.model.to(self.dtype)
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.processor = CLIPImageProcessor(crop_size=224,
|
||||
do_center_crop=True,
|
||||
do_convert_rgb=True,
|
||||
do_normalize=True,
|
||||
do_resize=True,
|
||||
image_mean=[ 0.48145466,0.4578275,0.40821073],
|
||||
image_std=[0.26862954,0.26130258,0.27577711],
|
||||
resample=3, #bicubic
|
||||
size=224)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
|
||||
def encode_image(self, image):
|
||||
img = torch.clip((255. * image), 0, 255).round().int()
|
||||
img = list(map(lambda a: a, img))
|
||||
inputs = self.processor(images=img, return_tensors="pt")
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = inputs['pixel_values'].to(self.load_device)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device))
|
||||
|
||||
if self.dtype != torch.float32:
|
||||
precision_scope = torch.autocast
|
||||
|
||||
79
comfy/conds.py
Normal file
79
comfy/conds.py
Normal file
@ -0,0 +1,79 @@
|
||||
import enum
|
||||
import torch
|
||||
import math
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||
return abs(a*b) // math.gcd(a, b)
|
||||
|
||||
class CONDRegular:
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def _copy_with(self, cond):
|
||||
return self.__class__(cond)
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond.shape != other.cond.shape:
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
conds = [self.cond]
|
||||
for x in others:
|
||||
conds.append(x.cond)
|
||||
return torch.cat(conds)
|
||||
|
||||
class CONDNoiseShape(CONDRegular):
|
||||
def process_cond(self, batch_size, device, area, **kwargs):
|
||||
data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
|
||||
|
||||
|
||||
class CONDCrossAttn(CONDRegular):
|
||||
def can_concat(self, other):
|
||||
s1 = self.cond.shape
|
||||
s2 = other.cond.shape
|
||||
if s1 != s2:
|
||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||
return False
|
||||
|
||||
mult_min = lcm(s1[1], s2[1])
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
conds = [self.cond]
|
||||
crossattn_max_len = self.cond.shape[1]
|
||||
for x in others:
|
||||
c = x.cond
|
||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
||||
conds.append(c)
|
||||
|
||||
out = []
|
||||
for c in conds:
|
||||
if c.shape[1] < crossattn_max_len:
|
||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||
out.append(c)
|
||||
return torch.cat(out)
|
||||
|
||||
class CONDConstant(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
return self._copy_with(self.cond)
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond != other.cond:
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
return self.cond
|
||||
@ -33,7 +33,7 @@ class ControlBase:
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
self.strength = 1.0
|
||||
self.timestep_percent_range = (1.0, 0.0)
|
||||
self.timestep_percent_range = (0.0, 1.0)
|
||||
self.timestep_range = None
|
||||
|
||||
if device is None:
|
||||
@ -42,7 +42,7 @@ class ControlBase:
|
||||
self.previous_controlnet = None
|
||||
self.global_average_pooling = False
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
|
||||
self.cond_hint_original = cond_hint
|
||||
self.strength = strength
|
||||
self.timestep_percent_range = timestep_percent_range
|
||||
@ -132,6 +132,7 @@ class ControlNet(ControlBase):
|
||||
self.control_model = control_model
|
||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.model_sampling_current = None
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
control_prev = None
|
||||
@ -156,10 +157,13 @@ class ControlNet(ControlBase):
|
||||
|
||||
|
||||
context = cond['c_crossattn']
|
||||
y = cond.get('c_adm', None)
|
||||
y = cond.get('y', None)
|
||||
if y is not None:
|
||||
y = y.to(self.control_model.dtype)
|
||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
|
||||
return self.control_merge(None, control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
@ -172,6 +176,14 @@ class ControlNet(ControlBase):
|
||||
out.append(self.control_model_wrapped)
|
||||
return out
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.model_sampling_current = model.model_sampling
|
||||
|
||||
def cleanup(self):
|
||||
self.model_sampling_current = None
|
||||
super().cleanup()
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
|
||||
@ -852,7 +852,13 @@ class SigmaConvert:
|
||||
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
||||
return log_mean_coeff - log_std
|
||||
|
||||
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
||||
def predict_eps_sigma(model, input, sigma_in, **kwargs):
|
||||
sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
|
||||
input = input * ((sigma ** 2 + 1.0) ** 0.5)
|
||||
return (input - model(input, sigma_in, **kwargs)) / sigma
|
||||
|
||||
|
||||
def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
||||
timesteps = sigmas.clone()
|
||||
if sigmas[-1] == 0:
|
||||
timesteps = sigmas[:]
|
||||
@ -874,14 +880,14 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
||||
model_type = "noise"
|
||||
|
||||
model_fn = model_wrapper(
|
||||
model.predict_eps_sigma,
|
||||
lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
|
||||
ns,
|
||||
model_type=model_type,
|
||||
guidance_type="uncond",
|
||||
model_kwargs=extra_args,
|
||||
)
|
||||
|
||||
order = min(3, len(timesteps) - 1)
|
||||
order = min(3, len(timesteps) - 2)
|
||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
||||
x /= ns.marginal_alpha(timesteps[-1])
|
||||
|
||||
@ -1,194 +0,0 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from . import sampling, utils
|
||||
|
||||
|
||||
class VDenoiser(nn.Module):
|
||||
"""A v-diffusion-pytorch model wrapper for k-diffusion."""
|
||||
|
||||
def __init__(self, inner_model):
|
||||
super().__init__()
|
||||
self.inner_model = inner_model
|
||||
self.sigma_data = 1.
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
|
||||
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
return c_skip, c_out, c_in
|
||||
|
||||
def sigma_to_t(self, sigma):
|
||||
return sigma.atan() / math.pi * 2
|
||||
|
||||
def t_to_sigma(self, t):
|
||||
return (t * math.pi / 2).tan()
|
||||
|
||||
def loss(self, input, noise, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
||||
model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
target = (input - c_skip * noised_input) / c_out
|
||||
return (model_output - target).pow(2).flatten(1).mean(1)
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
|
||||
|
||||
|
||||
class DiscreteSchedule(nn.Module):
|
||||
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
|
||||
levels."""
|
||||
|
||||
def __init__(self, sigmas, quantize):
|
||||
super().__init__()
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
self.quantize = quantize
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def get_sigmas(self, n=None):
|
||||
if n is None:
|
||||
return sampling.append_zero(self.sigmas.flip(0))
|
||||
t_max = len(self.sigmas) - 1
|
||||
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
|
||||
return sampling.append_zero(self.t_to_sigma(t))
|
||||
|
||||
def sigma_to_discrete_timestep(self, sigma):
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||
|
||||
def sigma_to_t(self, sigma, quantize=None):
|
||||
quantize = self.quantize if quantize is None else quantize
|
||||
if quantize:
|
||||
return self.sigma_to_discrete_timestep(sigma)
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = w.clamp(0, 1)
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
return t.view(sigma.shape)
|
||||
|
||||
def t_to_sigma(self, t):
|
||||
t = t.float()
|
||||
low_idx = t.floor().long()
|
||||
high_idx = t.ceil().long()
|
||||
w = t-low_idx if t.device.type == 'mps' else t.frac()
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp()
|
||||
|
||||
def predict_eps_discrete_timestep(self, input, t, **kwargs):
|
||||
if t.dtype != torch.int64 and t.dtype != torch.int32:
|
||||
t = t.round()
|
||||
sigma = self.t_to_sigma(t)
|
||||
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
|
||||
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
|
||||
|
||||
def predict_eps_sigma(self, input, sigma, **kwargs):
|
||||
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
|
||||
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
|
||||
|
||||
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
||||
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
||||
noise)."""
|
||||
|
||||
def __init__(self, model, alphas_cumprod, quantize):
|
||||
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
|
||||
self.inner_model = model
|
||||
self.sigma_data = 1.
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
c_out = -sigma
|
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
return c_out, c_in
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
return self.inner_model(*args, **kwargs)
|
||||
|
||||
def loss(self, input, noise, sigma, **kwargs):
|
||||
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
||||
eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
return (eps - noise).pow(2).flatten(1).mean(1)
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
return input + eps * c_out
|
||||
|
||||
|
||||
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
|
||||
"""A wrapper for OpenAI diffusion models."""
|
||||
|
||||
def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
|
||||
alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
|
||||
super().__init__(model, alphas_cumprod, quantize=quantize)
|
||||
self.has_learned_sigmas = has_learned_sigmas
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
model_output = self.inner_model(*args, **kwargs)
|
||||
if self.has_learned_sigmas:
|
||||
return model_output.chunk(2, dim=1)[0]
|
||||
return model_output
|
||||
|
||||
|
||||
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
|
||||
"""A wrapper for CompVis diffusion models."""
|
||||
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
return self.inner_model.apply_model(*args, **kwargs)
|
||||
|
||||
|
||||
class DiscreteVDDPMDenoiser(DiscreteSchedule):
|
||||
"""A wrapper for discrete schedule DDPM models that output v."""
|
||||
|
||||
def __init__(self, model, alphas_cumprod, quantize):
|
||||
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
|
||||
self.inner_model = model
|
||||
self.sigma_data = 1.
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
|
||||
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
return c_skip, c_out, c_in
|
||||
|
||||
def get_v(self, *args, **kwargs):
|
||||
return self.inner_model(*args, **kwargs)
|
||||
|
||||
def loss(self, input, noise, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
||||
model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
target = (input - c_skip * noised_input) / c_out
|
||||
return (model_output - target).pow(2).flatten(1).mean(1)
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
|
||||
|
||||
|
||||
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
|
||||
"""A wrapper for CompVis diffusion models that output v."""
|
||||
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_v(self, x, t, cond, **kwargs):
|
||||
return self.inner_model.apply_model(x, t, cond)
|
||||
@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
|
||||
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
|
||||
return mu
|
||||
|
||||
|
||||
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@ -737,3 +736,75 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
|
||||
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
x = denoised
|
||||
if sigmas[i + 1] > 0:
|
||||
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
s_end = sigmas[-1]
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
if sigmas[i + 1] == s_end:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
elif sigmas[i + 2] == s_end:
|
||||
|
||||
# Heun's method
|
||||
x_2 = x + d * dt
|
||||
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||
|
||||
w = 2 * sigmas[0]
|
||||
w2 = sigmas[i+1]/w
|
||||
w1 = 1 - w2
|
||||
|
||||
d_prime = d * w1 + d_2 * w2
|
||||
|
||||
|
||||
x = x + d_prime * dt
|
||||
|
||||
else:
|
||||
# Heun++
|
||||
x_2 = x + d * dt
|
||||
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||
dt_2 = sigmas[i + 2] - sigmas[i + 1]
|
||||
|
||||
x_3 = x_2 + d_2 * dt_2
|
||||
denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
|
||||
d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
|
||||
|
||||
w = 3 * sigmas[0]
|
||||
w2 = sigmas[i + 1] / w
|
||||
w3 = sigmas[i + 2] / w
|
||||
w1 = 1 - w2 - w3
|
||||
|
||||
d_prime = w1 * d + w2 * d_2 + w3 * d_3
|
||||
x = x + d_prime * dt
|
||||
return x
|
||||
|
||||
@ -1,418 +0,0 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device
|
||||
self.parameterization = kwargs.get("parameterization", "eps")
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != self.device:
|
||||
attr = attr.float().to(self.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose)
|
||||
|
||||
def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True):
|
||||
self.ddim_timesteps = torch.tensor(ddim_timesteps)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
|
||||
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_custom(self,
|
||||
ddim_timesteps,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
denoise_function=None,
|
||||
extra_args=None,
|
||||
to_zero=True,
|
||||
end_step=None,
|
||||
disable_pbar=False,
|
||||
**kwargs
|
||||
):
|
||||
self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
|
||||
samples, intermediates = self.ddim_sampling(conditioning, x_T.shape,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule,
|
||||
denoise_function=denoise_function,
|
||||
extra_args=extra_args,
|
||||
to_zero=to_zero,
|
||||
end_step=end_step,
|
||||
disable_pbar=disable_pbar
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule,
|
||||
denoise_function=None,
|
||||
extra_args=None
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x_start)
|
||||
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
||||
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
|
||||
device = self.model.alphas_cumprod.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else timesteps.flip(0)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
# print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
if ucg_schedule is not None:
|
||||
assert len(ucg_schedule) == len(time_range)
|
||||
unconditional_guidance_scale = ucg_schedule[i]
|
||||
|
||||
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
|
||||
img, pred_x0 = outs
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
if to_zero:
|
||||
img = pred_x0
|
||||
else:
|
||||
if ddim_use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
img /= sqrt_alphas_cumprod[index - 1]
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||
dynamic_threshold=None, denoise_function=None, extra_args=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if denoise_function is not None:
|
||||
model_output = denoise_function(x, t, **extra_args)
|
||||
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [torch.cat([
|
||||
unconditional_conditioning[k][i],
|
||||
c[k][i]]) for i in range(len(c[k]))]
|
||||
else:
|
||||
c_in[k] = torch.cat([
|
||||
unconditional_conditioning[k],
|
||||
c[k]])
|
||||
elif isinstance(c, list):
|
||||
c_in = list()
|
||||
assert isinstance(unconditional_conditioning, list)
|
||||
for i in range(len(c)):
|
||||
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
||||
|
||||
if self.parameterization == "v":
|
||||
e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
|
||||
else:
|
||||
e_t = model_output
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.parameterization == "eps", 'not implemented'
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
if self.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
else:
|
||||
pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
|
||||
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
if dynamic_threshold is not None:
|
||||
raise NotImplementedError()
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
|
||||
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
|
||||
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
|
||||
|
||||
assert t_enc <= num_reference_steps
|
||||
num_steps = t_enc
|
||||
|
||||
if use_original_steps:
|
||||
alphas_next = self.alphas_cumprod[:num_steps]
|
||||
alphas = self.alphas_cumprod_prev[:num_steps]
|
||||
else:
|
||||
alphas_next = self.ddim_alphas[:num_steps]
|
||||
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
||||
|
||||
x_next = x0
|
||||
intermediates = []
|
||||
inter_steps = []
|
||||
for i in tqdm(range(num_steps), desc='Encoding Image'):
|
||||
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
|
||||
if unconditional_guidance_scale == 1.:
|
||||
noise_pred = self.model.apply_model(x_next, t, c)
|
||||
else:
|
||||
assert unconditional_conditioning is not None
|
||||
e_t_uncond, noise_pred = torch.chunk(
|
||||
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
|
||||
torch.cat((unconditional_conditioning, c))), 2)
|
||||
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
|
||||
|
||||
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
||||
weighted_noise_pred = alphas_next[i].sqrt() * (
|
||||
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
|
||||
x_next = xt_weighted + weighted_noise_pred
|
||||
if return_intermediates and i % (
|
||||
num_steps // return_intermediates) == 0 and i < num_steps - 1:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
elif return_intermediates and i >= num_steps - 2:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
if callback: callback(i)
|
||||
|
||||
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
|
||||
if return_intermediates:
|
||||
out.update({'intermediates': intermediates})
|
||||
return x_next, out
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
if max_denoise:
|
||||
noise_multiplier = 1.0
|
||||
else:
|
||||
noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
|
||||
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
use_original_steps=False, callback=None):
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
if callback: callback(i)
|
||||
return x_dec
|
||||
@ -1 +0,0 @@
|
||||
from .sampler import DPMSolverSampler
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,96 +0,0 @@
|
||||
"""SAMPLING ONLY."""
|
||||
import torch
|
||||
|
||||
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||
|
||||
MODEL_TYPES = {
|
||||
"eps": "noise",
|
||||
"v": "v"
|
||||
}
|
||||
|
||||
|
||||
class DPMSolverSampler(object):
|
||||
def __init__(self, model, device=torch.device("cuda"), **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.device = device
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
||||
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != self.device:
|
||||
attr = attr.to(self.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||
if isinstance(ctmp, torch.Tensor):
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if isinstance(conditioning, torch.Tensor):
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
|
||||
print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
|
||||
|
||||
device = self.model.betas.device
|
||||
if x_T is None:
|
||||
img = torch.randn(size, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
|
||||
|
||||
model_fn = model_wrapper(
|
||||
lambda x, t, c: self.model.apply_model(x, t, c),
|
||||
ns,
|
||||
model_type=MODEL_TYPES[self.model.parameterization],
|
||||
guidance_type="classifier-free",
|
||||
condition=conditioning,
|
||||
unconditional_condition=unconditional_conditioning,
|
||||
guidance_scale=unconditional_guidance_scale,
|
||||
)
|
||||
|
||||
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
|
||||
lower_order_final=True)
|
||||
|
||||
return x.to(device), None
|
||||
@ -1,245 +0,0 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != self.device:
|
||||
attr = attr.to(self.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
if ddim_eta != 0:
|
||||
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||
dynamic_threshold=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||
old_eps = []
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps, t_next=ts_next,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
|
||||
dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
||||
@ -1,22 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
|
||||
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def norm_thresholding(x0, value):
|
||||
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
|
||||
return x0 * (value / s)
|
||||
|
||||
|
||||
def spatial_norm_thresholding(x0, value):
|
||||
# b c h w
|
||||
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
|
||||
return x0 * (value / s)
|
||||
@ -5,8 +5,10 @@ import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from typing import Optional, Any
|
||||
from functools import partial
|
||||
|
||||
from .diffusionmodules.util import checkpoint
|
||||
|
||||
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
from comfy import model_management
|
||||
@ -160,32 +162,19 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
|
||||
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
|
||||
|
||||
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
|
||||
|
||||
kv_chunk_size_min = None
|
||||
kv_chunk_size = None
|
||||
query_chunk_size = None
|
||||
|
||||
#not sure at all about the math here
|
||||
#TODO: tweak this
|
||||
if mem_free_total > 8192 * 1024 * 1024 * 1.3:
|
||||
query_chunk_size_x = 1024 * 4
|
||||
elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
|
||||
query_chunk_size_x = 1024 * 2
|
||||
else:
|
||||
query_chunk_size_x = 1024
|
||||
kv_chunk_size_min_x = None
|
||||
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
|
||||
if kv_chunk_size_x < 1024:
|
||||
kv_chunk_size_x = None
|
||||
for x in [4096, 2048, 1024, 512, 256]:
|
||||
count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
|
||||
if count >= k_tokens:
|
||||
kv_chunk_size = k_tokens
|
||||
query_chunk_size = x
|
||||
break
|
||||
|
||||
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
||||
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
||||
# i.e. send it down the unchunked fast-path
|
||||
query_chunk_size = q_tokens
|
||||
kv_chunk_size = k_tokens
|
||||
else:
|
||||
query_chunk_size = query_chunk_size_x
|
||||
kv_chunk_size = kv_chunk_size_x
|
||||
kv_chunk_size_min = kv_chunk_size_min_x
|
||||
if query_chunk_size is None:
|
||||
query_chunk_size = 512
|
||||
|
||||
hidden_states = efficient_dot_product_attention(
|
||||
query,
|
||||
@ -222,9 +211,14 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
element_size = 4
|
||||
else:
|
||||
element_size = q.element_size()
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
|
||||
modifier = 3
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
@ -252,10 +246,10 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
|
||||
else:
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
|
||||
first_op_done = True
|
||||
|
||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||
del s1
|
||||
first_op_done = True
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
@ -284,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
)
|
||||
return r1
|
||||
|
||||
BROKEN_XFORMERS = False
|
||||
try:
|
||||
x_vers = xformers.__version__
|
||||
#I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
|
||||
BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
|
||||
except:
|
||||
pass
|
||||
|
||||
def attention_xformers(q, k, v, heads, mask=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
if BROKEN_XFORMERS:
|
||||
if b * heads > 65535:
|
||||
return attention_pytorch(q, k, v, heads, mask)
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
@ -378,53 +383,72 @@ class CrossAttention(nn.Module):
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
||||
disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
|
||||
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
|
||||
super().__init__()
|
||||
|
||||
self.ff_in = ff_in or inner_dim is not None
|
||||
if inner_dim is None:
|
||||
inner_dim = dim
|
||||
|
||||
self.is_res = inner_dim == dim
|
||||
|
||||
if self.ff_in:
|
||||
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if disable_temporal_crossattention:
|
||||
if switch_temporal_ca_to_sa:
|
||||
raise ValueError
|
||||
else:
|
||||
self.attn2 = None
|
||||
else:
|
||||
context_dim_attn2 = None
|
||||
if not switch_temporal_ca_to_sa:
|
||||
context_dim_attn2 = context_dim
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.checkpoint = checkpoint
|
||||
self.n_heads = n_heads
|
||||
self.d_head = d_head
|
||||
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
|
||||
|
||||
def forward(self, x, context=None, transformer_options={}):
|
||||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None, transformer_options={}):
|
||||
extra_options = {}
|
||||
block = None
|
||||
block_index = 0
|
||||
if "current_index" in transformer_options:
|
||||
extra_options["transformer_index"] = transformer_options["current_index"]
|
||||
if "block_index" in transformer_options:
|
||||
block_index = transformer_options["block_index"]
|
||||
extra_options["block_index"] = block_index
|
||||
if "original_shape" in transformer_options:
|
||||
extra_options["original_shape"] = transformer_options["original_shape"]
|
||||
if "block" in transformer_options:
|
||||
block = transformer_options["block"]
|
||||
extra_options["block"] = block
|
||||
if "cond_or_uncond" in transformer_options:
|
||||
extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"]
|
||||
if "patches" in transformer_options:
|
||||
transformer_patches = transformer_options["patches"]
|
||||
else:
|
||||
transformer_patches = {}
|
||||
block = transformer_options.get("block", None)
|
||||
block_index = transformer_options.get("block_index", 0)
|
||||
transformer_patches = {}
|
||||
transformer_patches_replace = {}
|
||||
|
||||
for k in transformer_options:
|
||||
if k == "patches":
|
||||
transformer_patches = transformer_options[k]
|
||||
elif k == "patches_replace":
|
||||
transformer_patches_replace = transformer_options[k]
|
||||
else:
|
||||
extra_options[k] = transformer_options[k]
|
||||
|
||||
extra_options["n_heads"] = self.n_heads
|
||||
extra_options["dim_head"] = self.d_head
|
||||
|
||||
if "patches_replace" in transformer_options:
|
||||
transformer_patches_replace = transformer_options["patches_replace"]
|
||||
else:
|
||||
transformer_patches_replace = {}
|
||||
if self.ff_in:
|
||||
x_skip = x
|
||||
x = self.ff_in(self.norm_in(x))
|
||||
if self.is_res:
|
||||
x += x_skip
|
||||
|
||||
n = self.norm1(x)
|
||||
if self.disable_self_attn:
|
||||
@ -473,31 +497,34 @@ class BasicTransformerBlock(nn.Module):
|
||||
for p in patch:
|
||||
x = p(x, extra_options)
|
||||
|
||||
n = self.norm2(x)
|
||||
|
||||
context_attn2 = context
|
||||
value_attn2 = None
|
||||
if "attn2_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn2_patch"]
|
||||
value_attn2 = context_attn2
|
||||
for p in patch:
|
||||
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
|
||||
|
||||
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
|
||||
block_attn2 = transformer_block
|
||||
if block_attn2 not in attn2_replace_patch:
|
||||
block_attn2 = block
|
||||
|
||||
if block_attn2 in attn2_replace_patch:
|
||||
if value_attn2 is None:
|
||||
if self.attn2 is not None:
|
||||
n = self.norm2(x)
|
||||
if self.switch_temporal_ca_to_sa:
|
||||
context_attn2 = n
|
||||
else:
|
||||
context_attn2 = context
|
||||
value_attn2 = None
|
||||
if "attn2_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn2_patch"]
|
||||
value_attn2 = context_attn2
|
||||
n = self.attn2.to_q(n)
|
||||
context_attn2 = self.attn2.to_k(context_attn2)
|
||||
value_attn2 = self.attn2.to_v(value_attn2)
|
||||
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
|
||||
n = self.attn2.to_out(n)
|
||||
else:
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||
for p in patch:
|
||||
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
|
||||
|
||||
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
|
||||
block_attn2 = transformer_block
|
||||
if block_attn2 not in attn2_replace_patch:
|
||||
block_attn2 = block
|
||||
|
||||
if block_attn2 in attn2_replace_patch:
|
||||
if value_attn2 is None:
|
||||
value_attn2 = context_attn2
|
||||
n = self.attn2.to_q(n)
|
||||
context_attn2 = self.attn2.to_k(context_attn2)
|
||||
value_attn2 = self.attn2.to_v(value_attn2)
|
||||
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
|
||||
n = self.attn2.to_out(n)
|
||||
else:
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||
|
||||
if "attn2_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn2_output_patch"]
|
||||
@ -505,7 +532,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
n = p(n, extra_options)
|
||||
|
||||
x += n
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
if self.is_res:
|
||||
x_skip = x
|
||||
x = self.ff(self.norm3(x))
|
||||
if self.is_res:
|
||||
x += x_skip
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@ -573,3 +605,164 @@ class SpatialTransformer(nn.Module):
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
||||
|
||||
class SpatialVideoTransformer(SpatialTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.0,
|
||||
use_linear=False,
|
||||
context_dim=None,
|
||||
use_spatial_context=False,
|
||||
timesteps=None,
|
||||
merge_strategy: str = "fixed",
|
||||
merge_factor: float = 0.5,
|
||||
time_context_dim=None,
|
||||
ff_in=False,
|
||||
checkpoint=False,
|
||||
time_depth=1,
|
||||
disable_self_attn=False,
|
||||
disable_temporal_crossattention=False,
|
||||
max_time_embed_period: int = 10000,
|
||||
dtype=None, device=None, operations=comfy.ops
|
||||
):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=depth,
|
||||
dropout=dropout,
|
||||
use_checkpoint=checkpoint,
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
disable_self_attn=disable_self_attn,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.time_depth = time_depth
|
||||
self.depth = depth
|
||||
self.max_time_embed_period = max_time_embed_period
|
||||
|
||||
time_mix_d_head = d_head
|
||||
n_time_mix_heads = n_heads
|
||||
|
||||
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
|
||||
|
||||
inner_dim = n_heads * d_head
|
||||
if use_spatial_context:
|
||||
time_context_dim = context_dim
|
||||
|
||||
self.time_stack = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
n_time_mix_heads,
|
||||
time_mix_d_head,
|
||||
dropout=dropout,
|
||||
context_dim=time_context_dim,
|
||||
# timesteps=timesteps,
|
||||
checkpoint=checkpoint,
|
||||
ff_in=ff_in,
|
||||
inner_dim=time_mix_inner_dim,
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
]
|
||||
)
|
||||
|
||||
assert len(self.time_stack) == len(self.transformer_blocks)
|
||||
|
||||
self.use_spatial_context = use_spatial_context
|
||||
self.in_channels = in_channels
|
||||
|
||||
time_embed_dim = self.in_channels * 4
|
||||
self.time_pos_embed = nn.Sequential(
|
||||
operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.time_mixer = AlphaBlender(
|
||||
alpha=merge_factor, merge_strategy=merge_strategy
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
time_context: Optional[torch.Tensor] = None,
|
||||
timesteps: Optional[int] = None,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
transformer_options={}
|
||||
) -> torch.Tensor:
|
||||
_, _, h, w = x.shape
|
||||
x_in = x
|
||||
spatial_context = None
|
||||
if exists(context):
|
||||
spatial_context = context
|
||||
|
||||
if self.use_spatial_context:
|
||||
assert (
|
||||
context.ndim == 3
|
||||
), f"n dims of spatial context should be 3 but are {context.ndim}"
|
||||
|
||||
if time_context is None:
|
||||
time_context = context
|
||||
time_context_first_timestep = time_context[::timesteps]
|
||||
time_context = repeat(
|
||||
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
|
||||
)
|
||||
elif time_context is not None and not self.use_spatial_context:
|
||||
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
|
||||
if time_context.ndim == 2:
|
||||
time_context = rearrange(time_context, "b c -> b 1 c")
|
||||
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
|
||||
num_frames = torch.arange(timesteps, device=x.device)
|
||||
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
|
||||
emb = self.time_pos_embed(t_emb)
|
||||
emb = emb[:, None, :]
|
||||
|
||||
for it_, (block, mix_block) in enumerate(
|
||||
zip(self.transformer_blocks, self.time_stack)
|
||||
):
|
||||
transformer_options["block_index"] = it_
|
||||
x = block(
|
||||
x,
|
||||
context=spatial_context,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x_mix = x
|
||||
x_mix = x_mix + emb
|
||||
|
||||
B, S, C = x_mix.shape
|
||||
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
|
||||
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
|
||||
x_mix = rearrange(
|
||||
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
|
||||
)
|
||||
|
||||
x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
|
||||
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
out = x + x_in
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@ -5,6 +5,8 @@ import numpy as np
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from functools import partial
|
||||
|
||||
from .util import (
|
||||
checkpoint,
|
||||
@ -12,8 +14,9 @@ from .util import (
|
||||
zero_module,
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
AlphaBlender,
|
||||
)
|
||||
from ..attention import SpatialTransformer
|
||||
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
||||
from comfy.ldm.util import exists
|
||||
import comfy.ops
|
||||
|
||||
@ -28,6 +31,26 @@ class TimestepBlock(nn.Module):
|
||||
Apply the module to `x` given `emb` timestep embeddings.
|
||||
"""
|
||||
|
||||
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
||||
for layer in ts:
|
||||
if isinstance(layer, VideoResBlock):
|
||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||
elif isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialVideoTransformer):
|
||||
x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
|
||||
if "transformer_index" in transformer_options:
|
||||
transformer_options["transformer_index"] += 1
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context, transformer_options)
|
||||
if "transformer_index" in transformer_options:
|
||||
transformer_options["transformer_index"] += 1
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
@ -35,31 +58,8 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context, transformer_options)
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
|
||||
for layer in ts:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context, transformer_options)
|
||||
transformer_options["current_index"] += 1
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
def forward(self, *args, **kwargs):
|
||||
return forward_timestep_embed(self, *args, **kwargs)
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
@ -154,6 +154,9 @@ class ResBlock(TimestepBlock):
|
||||
use_checkpoint=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=3,
|
||||
exchange_temb_dims=False,
|
||||
skip_t_emb=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=comfy.ops
|
||||
@ -166,11 +169,17 @@ class ResBlock(TimestepBlock):
|
||||
self.use_conv = use_conv
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
self.exchange_temb_dims = exchange_temb_dims
|
||||
|
||||
if isinstance(kernel_size, list):
|
||||
padding = [k // 2 for k in kernel_size]
|
||||
else:
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
nn.GroupNorm(32, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
|
||||
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
@ -184,19 +193,24 @@ class ResBlock(TimestepBlock):
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
self.skip_t_emb = skip_t_emb
|
||||
if self.skip_t_emb:
|
||||
self.emb_layers = None
|
||||
self.exchange_temb_dims = False
|
||||
else:
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
|
||||
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
|
||||
),
|
||||
)
|
||||
|
||||
@ -204,7 +218,7 @@ class ResBlock(TimestepBlock):
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = operations.conv_nd(
|
||||
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
|
||||
dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
|
||||
)
|
||||
else:
|
||||
self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
|
||||
@ -230,19 +244,110 @@ class ResBlock(TimestepBlock):
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
|
||||
emb_out = None
|
||||
if not self.skip_t_emb:
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
if self.use_scale_shift_norm:
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
scale, shift = th.chunk(emb_out, 2, dim=1)
|
||||
h = out_norm(h) * (1 + scale) + shift
|
||||
h = out_norm(h)
|
||||
if emb_out is not None:
|
||||
scale, shift = th.chunk(emb_out, 2, dim=1)
|
||||
h *= (1 + scale)
|
||||
h += shift
|
||||
h = out_rest(h)
|
||||
else:
|
||||
h = h + emb_out
|
||||
if emb_out is not None:
|
||||
if self.exchange_temb_dims:
|
||||
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
|
||||
h = h + emb_out
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class VideoResBlock(ResBlock):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
emb_channels: int,
|
||||
dropout: float,
|
||||
video_kernel_size=3,
|
||||
merge_strategy: str = "fixed",
|
||||
merge_factor: float = 0.5,
|
||||
out_channels=None,
|
||||
use_conv: bool = False,
|
||||
use_scale_shift_norm: bool = False,
|
||||
dims: int = 2,
|
||||
use_checkpoint: bool = False,
|
||||
up: bool = False,
|
||||
down: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=comfy.ops
|
||||
):
|
||||
super().__init__(
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=out_channels,
|
||||
use_conv=use_conv,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
up=up,
|
||||
down=down,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
self.time_stack = ResBlock(
|
||||
default(out_channels, channels),
|
||||
emb_channels,
|
||||
dropout=dropout,
|
||||
dims=3,
|
||||
out_channels=default(out_channels, channels),
|
||||
use_scale_shift_norm=False,
|
||||
use_conv=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=video_kernel_size,
|
||||
use_checkpoint=use_checkpoint,
|
||||
exchange_temb_dims=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
self.time_mixer = AlphaBlender(
|
||||
alpha=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
rearrange_pattern="b t -> b 1 t 1 1",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: th.Tensor,
|
||||
emb: th.Tensor,
|
||||
num_video_frames: int,
|
||||
image_only_indicator = None,
|
||||
) -> th.Tensor:
|
||||
x = super().forward(x, emb)
|
||||
|
||||
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
||||
|
||||
x = self.time_stack(
|
||||
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
|
||||
)
|
||||
x = self.time_mixer(
|
||||
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
|
||||
)
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
return x
|
||||
|
||||
|
||||
class Timestep(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
@ -251,6 +356,15 @@ class Timestep(nn.Module):
|
||||
def forward(self, t):
|
||||
return timestep_embedding(t, self.dim)
|
||||
|
||||
def apply_control(h, control, name):
|
||||
if control is not None and name in control and len(control[name]) > 0:
|
||||
ctrl = control[name].pop()
|
||||
if ctrl is not None:
|
||||
try:
|
||||
h += ctrl
|
||||
except:
|
||||
print("warning control could not be applied", h.shape, ctrl.shape)
|
||||
return h
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""
|
||||
@ -259,10 +373,6 @@ class UNetModel(nn.Module):
|
||||
:param model_channels: base channel count for the model.
|
||||
:param out_channels: channels in the output Tensor.
|
||||
:param num_res_blocks: number of residual blocks per downsample.
|
||||
:param attention_resolutions: a collection of downsample rates at which
|
||||
attention will take place. May be a set, list, or tuple.
|
||||
For example, if this contains 4, then at 4x downsampling, attention
|
||||
will be used.
|
||||
:param dropout: the dropout probability.
|
||||
:param channel_mult: channel multiplier for each level of the UNet.
|
||||
:param conv_resample: if True, use learned convolutions for upsampling and
|
||||
@ -289,7 +399,6 @@ class UNetModel(nn.Module):
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
@ -314,6 +423,17 @@ class UNetModel(nn.Module):
|
||||
use_linear_in_transformer=False,
|
||||
adm_in_channels=None,
|
||||
transformer_depth_middle=None,
|
||||
transformer_depth_output=None,
|
||||
use_temporal_resblock=False,
|
||||
use_temporal_attention=False,
|
||||
time_context_dim=None,
|
||||
extra_ff_mix_layer=False,
|
||||
use_spatial_context=False,
|
||||
merge_strategy=None,
|
||||
merge_factor=0.0,
|
||||
video_kernel_size=None,
|
||||
disable_temporal_crossattention=False,
|
||||
max_ddpm_temb_period=10000,
|
||||
device=None,
|
||||
operations=comfy.ops,
|
||||
):
|
||||
@ -341,10 +461,7 @@ class UNetModel(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = len(channel_mult) * [transformer_depth]
|
||||
if transformer_depth_middle is None:
|
||||
transformer_depth_middle = transformer_depth[-1]
|
||||
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
@ -352,18 +469,16 @@ class UNetModel(nn.Module):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.")
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
transformer_depth = transformer_depth[:]
|
||||
transformer_depth_output = transformer_depth_output[:]
|
||||
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
@ -373,8 +488,12 @@ class UNetModel(nn.Module):
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.use_temporal_resblocks = use_temporal_resblock
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
self.default_num_video_frames = None
|
||||
self.default_image_only_indicator = None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
|
||||
@ -411,13 +530,104 @@ class UNetModel(nn.Module):
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
|
||||
def get_attention_layer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=1,
|
||||
context_dim=None,
|
||||
use_checkpoint=False,
|
||||
disable_self_attn=False,
|
||||
):
|
||||
if use_temporal_attention:
|
||||
return SpatialVideoTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=depth,
|
||||
context_dim=context_dim,
|
||||
time_context_dim=time_context_dim,
|
||||
dropout=dropout,
|
||||
ff_in=extra_ff_mix_layer,
|
||||
use_spatial_context=use_spatial_context,
|
||||
merge_strategy=merge_strategy,
|
||||
merge_factor=merge_factor,
|
||||
checkpoint=use_checkpoint,
|
||||
use_linear=use_linear_in_transformer,
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
max_time_embed_period=max_ddpm_temb_period,
|
||||
dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
return SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def get_resblock(
|
||||
merge_factor,
|
||||
merge_strategy,
|
||||
video_kernel_size,
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels,
|
||||
dims,
|
||||
use_checkpoint,
|
||||
use_scale_shift_norm,
|
||||
down=False,
|
||||
up=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=comfy.ops
|
||||
):
|
||||
if self.use_temporal_resblocks:
|
||||
return VideoResBlock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
channels=ch,
|
||||
emb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=down,
|
||||
up=up,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
else:
|
||||
return ResBlock(
|
||||
channels=ch,
|
||||
emb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_channels,
|
||||
use_checkpoint=use_checkpoint,
|
||||
dims=dims,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=down,
|
||||
up=up,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
@ -428,7 +638,8 @@ class UNetModel(nn.Module):
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
num_transformers = transformer_depth.pop(0)
|
||||
if num_transformers > 0:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
@ -443,11 +654,9 @@ class UNetModel(nn.Module):
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
layers.append(get_attention_layer(
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
@ -456,10 +665,13 @@ class UNetModel(nn.Module):
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
@ -488,35 +700,43 @@ class UNetModel(nn.Module):
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
mid_block = [
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=None,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
SpatialTransformer( # always uses a self-attn
|
||||
)]
|
||||
if transformer_depth_middle >= 0:
|
||||
mid_block += [get_attention_layer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=None,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
)
|
||||
)]
|
||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
@ -524,10 +744,13 @@ class UNetModel(nn.Module):
|
||||
for i in range(self.num_res_blocks[level] + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch + ich,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch + ich,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=model_channels * mult,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
@ -538,7 +761,8 @@ class UNetModel(nn.Module):
|
||||
)
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
num_transformers = transformer_depth_output.pop()
|
||||
if num_transformers > 0:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
@ -554,19 +778,21 @@ class UNetModel(nn.Module):
|
||||
|
||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
get_attention_layer(
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
if level and i == self.num_res_blocks[level]:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
@ -605,9 +831,13 @@ class UNetModel(nn.Module):
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
transformer_options["original_shape"] = list(x.shape)
|
||||
transformer_options["current_index"] = 0
|
||||
transformer_options["transformer_index"] = 0
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
|
||||
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
||||
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
|
||||
time_context = kwargs.get("time_context", None)
|
||||
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
@ -622,26 +852,28 @@ class UNetModel(nn.Module):
|
||||
h = x.type(self.dtype)
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
transformer_options["block"] = ("input", id)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options)
|
||||
if control is not None and 'input' in control and len(control['input']) > 0:
|
||||
ctrl = control['input'].pop()
|
||||
if ctrl is not None:
|
||||
h += ctrl
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'input')
|
||||
if "input_block_patch" in transformer_patches:
|
||||
patch = transformer_patches["input_block_patch"]
|
||||
for p in patch:
|
||||
h = p(h, transformer_options)
|
||||
|
||||
hs.append(h)
|
||||
if "input_block_patch_after_skip" in transformer_patches:
|
||||
patch = transformer_patches["input_block_patch_after_skip"]
|
||||
for p in patch:
|
||||
h = p(h, transformer_options)
|
||||
|
||||
transformer_options["block"] = ("middle", 0)
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
||||
ctrl = control['middle'].pop()
|
||||
if ctrl is not None:
|
||||
h += ctrl
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'middle')
|
||||
|
||||
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
transformer_options["block"] = ("output", id)
|
||||
hsp = hs.pop()
|
||||
if control is not None and 'output' in control and len(control['output']) > 0:
|
||||
ctrl = control['output'].pop()
|
||||
if ctrl is not None:
|
||||
hsp += ctrl
|
||||
hsp = apply_control(hsp, control, 'output')
|
||||
|
||||
if "output_block_patch" in transformer_patches:
|
||||
patch = transformer_patches["output_block_patch"]
|
||||
@ -654,7 +886,7 @@ class UNetModel(nn.Module):
|
||||
output_shape = hs[-1].shape
|
||||
else:
|
||||
output_shape = None
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
|
||||
@ -13,11 +13,78 @@ import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
from einops import repeat, rearrange
|
||||
|
||||
from comfy.ldm.util import instantiate_from_config
|
||||
import comfy.ops
|
||||
|
||||
class AlphaBlender(nn.Module):
|
||||
strategies = ["learned", "fixed", "learned_with_images"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
alpha: float,
|
||||
merge_strategy: str = "learned_with_images",
|
||||
rearrange_pattern: str = "b t -> (b t) 1 1",
|
||||
):
|
||||
super().__init__()
|
||||
self.merge_strategy = merge_strategy
|
||||
self.rearrange_pattern = rearrange_pattern
|
||||
|
||||
assert (
|
||||
merge_strategy in self.strategies
|
||||
), f"merge_strategy needs to be in {self.strategies}"
|
||||
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif (
|
||||
self.merge_strategy == "learned"
|
||||
or self.merge_strategy == "learned_with_images"
|
||||
):
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
|
||||
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
|
||||
if self.merge_strategy == "fixed":
|
||||
# make shape compatible
|
||||
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
|
||||
alpha = self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
alpha = torch.sigmoid(self.mix_factor)
|
||||
# make shape compatible
|
||||
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
||||
elif self.merge_strategy == "learned_with_images":
|
||||
assert image_only_indicator is not None, "need image_only_indicator ..."
|
||||
alpha = torch.where(
|
||||
image_only_indicator.bool(),
|
||||
torch.ones(1, 1, device=image_only_indicator.device),
|
||||
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
|
||||
)
|
||||
alpha = rearrange(alpha, self.rearrange_pattern)
|
||||
# make shape compatible
|
||||
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return alpha
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_spatial,
|
||||
x_temporal,
|
||||
image_only_indicator=None,
|
||||
) -> torch.Tensor:
|
||||
alpha = self.get_alpha(image_only_indicator)
|
||||
x = (
|
||||
alpha.to(x_spatial.dtype) * x_spatial
|
||||
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if schedule == "linear":
|
||||
betas = (
|
||||
@ -170,8 +237,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=timesteps.device)
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
|
||||
)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
|
||||
@ -83,7 +83,8 @@ def _summarize_chunk(
|
||||
)
|
||||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||
max_score = max_score.detach()
|
||||
torch.exp(attn_weights - max_score, out=attn_weights)
|
||||
attn_weights -= max_score
|
||||
torch.exp(attn_weights, out=attn_weights)
|
||||
exp_weights = attn_weights.to(value.dtype)
|
||||
exp_values = torch.bmm(exp_weights, value)
|
||||
max_score = max_score.squeeze(-1)
|
||||
|
||||
244
comfy/ldm/modules/temporal_ae.py
Normal file
244
comfy/ldm/modules/temporal_ae.py
Normal file
@ -0,0 +1,244 @@
|
||||
import functools
|
||||
from typing import Callable, Iterable, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
import comfy.ops
|
||||
|
||||
from .diffusionmodules.model import (
|
||||
AttnBlock,
|
||||
Decoder,
|
||||
ResnetBlock,
|
||||
)
|
||||
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
|
||||
from .attention import BasicTransformerBlock
|
||||
|
||||
def partialclass(cls, *args, **kwargs):
|
||||
class NewCls(cls):
|
||||
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
|
||||
|
||||
return NewCls
|
||||
|
||||
|
||||
class VideoResBlock(ResnetBlock):
|
||||
def __init__(
|
||||
self,
|
||||
out_channels,
|
||||
*args,
|
||||
dropout=0.0,
|
||||
video_kernel_size=3,
|
||||
alpha=0.0,
|
||||
merge_strategy="learned",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
|
||||
if video_kernel_size is None:
|
||||
video_kernel_size = [3, 1, 1]
|
||||
self.time_stack = ResBlock(
|
||||
channels=out_channels,
|
||||
emb_channels=0,
|
||||
dropout=dropout,
|
||||
dims=3,
|
||||
use_scale_shift_norm=False,
|
||||
use_conv=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=video_kernel_size,
|
||||
use_checkpoint=False,
|
||||
skip_t_emb=True,
|
||||
)
|
||||
|
||||
self.merge_strategy = merge_strategy
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif self.merge_strategy == "learned":
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def get_alpha(self, bs):
|
||||
if self.merge_strategy == "fixed":
|
||||
return self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
return torch.sigmoid(self.mix_factor)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, x, temb, skip_video=False, timesteps=None):
|
||||
b, c, h, w = x.shape
|
||||
if timesteps is None:
|
||||
timesteps = b
|
||||
|
||||
x = super().forward(x, temb)
|
||||
|
||||
if not skip_video:
|
||||
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
|
||||
x = self.time_stack(x, temb)
|
||||
|
||||
alpha = self.get_alpha(bs=b // timesteps)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix
|
||||
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
return x
|
||||
|
||||
|
||||
class AE3DConv(torch.nn.Conv2d):
|
||||
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
|
||||
super().__init__(in_channels, out_channels, *args, **kwargs)
|
||||
if isinstance(video_kernel_size, Iterable):
|
||||
padding = [int(k // 2) for k in video_kernel_size]
|
||||
else:
|
||||
padding = int(video_kernel_size // 2)
|
||||
|
||||
self.time_mix_conv = torch.nn.Conv3d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=video_kernel_size,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
def forward(self, input, timesteps=None, skip_video=False):
|
||||
if timesteps is None:
|
||||
timesteps = input.shape[0]
|
||||
x = super().forward(input)
|
||||
if skip_video:
|
||||
return x
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
x = self.time_mix_conv(x)
|
||||
return rearrange(x, "b c t h w -> (b t) c h w")
|
||||
|
||||
|
||||
class AttnVideoBlock(AttnBlock):
|
||||
def __init__(
|
||||
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
|
||||
):
|
||||
super().__init__(in_channels)
|
||||
# no context, single headed, as in base class
|
||||
self.time_mix_block = BasicTransformerBlock(
|
||||
dim=in_channels,
|
||||
n_heads=1,
|
||||
d_head=in_channels,
|
||||
checkpoint=False,
|
||||
ff_in=True,
|
||||
)
|
||||
|
||||
time_embed_dim = self.in_channels * 4
|
||||
self.video_time_embed = torch.nn.Sequential(
|
||||
comfy.ops.Linear(self.in_channels, time_embed_dim),
|
||||
torch.nn.SiLU(),
|
||||
comfy.ops.Linear(time_embed_dim, self.in_channels),
|
||||
)
|
||||
|
||||
self.merge_strategy = merge_strategy
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif self.merge_strategy == "learned":
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def forward(self, x, timesteps=None, skip_time_block=False):
|
||||
if skip_time_block:
|
||||
return super().forward(x)
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = x.shape[0]
|
||||
|
||||
x_in = x
|
||||
x = self.attention(x)
|
||||
h, w = x.shape[2:]
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
|
||||
x_mix = x
|
||||
num_frames = torch.arange(timesteps, device=x.device)
|
||||
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
|
||||
emb = self.video_time_embed(t_emb) # b, n_channels
|
||||
emb = emb[:, None, :]
|
||||
x_mix = x_mix + emb
|
||||
|
||||
alpha = self.get_alpha()
|
||||
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
|
||||
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
x = self.proj_out(x)
|
||||
|
||||
return x_in + x
|
||||
|
||||
def get_alpha(
|
||||
self,
|
||||
):
|
||||
if self.merge_strategy == "fixed":
|
||||
return self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
return torch.sigmoid(self.mix_factor)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
|
||||
|
||||
def make_time_attn(
|
||||
in_channels,
|
||||
attn_type="vanilla",
|
||||
attn_kwargs=None,
|
||||
alpha: float = 0,
|
||||
merge_strategy: str = "learned",
|
||||
):
|
||||
return partialclass(
|
||||
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
|
||||
)
|
||||
|
||||
|
||||
class Conv2DWrapper(torch.nn.Conv2d):
|
||||
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return super().forward(input)
|
||||
|
||||
|
||||
class VideoDecoder(Decoder):
|
||||
available_time_modes = ["all", "conv-only", "attn-only"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
video_kernel_size: Union[int, list] = 3,
|
||||
alpha: float = 0.0,
|
||||
merge_strategy: str = "learned",
|
||||
time_mode: str = "conv-only",
|
||||
**kwargs,
|
||||
):
|
||||
self.video_kernel_size = video_kernel_size
|
||||
self.alpha = alpha
|
||||
self.merge_strategy = merge_strategy
|
||||
self.time_mode = time_mode
|
||||
assert (
|
||||
self.time_mode in self.available_time_modes
|
||||
), f"time_mode parameter has to be in {self.available_time_modes}"
|
||||
|
||||
if self.time_mode != "attn-only":
|
||||
kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
|
||||
if self.time_mode not in ["conv-only", "only-last-conv"]:
|
||||
kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
|
||||
if self.time_mode not in ["attn-only", "only-last-conv"]:
|
||||
kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_last_layer(self, skip_time_mix=False, **kwargs):
|
||||
if self.time_mode == "attn-only":
|
||||
raise NotImplementedError("TODO")
|
||||
else:
|
||||
return (
|
||||
self.conv_out.time_mix_conv.weight
|
||||
if not skip_time_mix
|
||||
else self.conv_out.weight
|
||||
)
|
||||
@ -131,6 +131,18 @@ def load_lora(lora, to_load):
|
||||
loaded_keys.add(b_norm_name)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
|
||||
|
||||
diff_name = "{}.diff".format(x)
|
||||
diff_weight = lora.get(diff_name, None)
|
||||
if diff_weight is not None:
|
||||
patch_dict[to_load[x]] = (diff_weight,)
|
||||
loaded_keys.add(diff_name)
|
||||
|
||||
diff_bias_name = "{}.diff_b".format(x)
|
||||
diff_bias = lora.get(diff_bias_name, None)
|
||||
if diff_bias is not None:
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
print("lora key not loaded", x)
|
||||
@ -141,9 +153,9 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
for b in range(32):
|
||||
for b in range(32): #TODO: clean up
|
||||
for c in LORA_CLIP_MAP:
|
||||
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
@ -154,6 +166,8 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
|
||||
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||
key_map[lora_key] = k
|
||||
clip_l_present = True
|
||||
|
||||
@ -1,16 +1,37 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
|
||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
import comfy.model_management
|
||||
import numpy as np
|
||||
import comfy.conds
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
|
||||
class ModelType(Enum):
|
||||
EPS = 1
|
||||
V_PREDICTION = 2
|
||||
V_PREDICTION_EDM = 3
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
s = ModelSamplingDiscrete
|
||||
|
||||
if model_type == ModelType.EPS:
|
||||
c = EPS
|
||||
elif model_type == ModelType.V_PREDICTION:
|
||||
c = V_PREDICTION
|
||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousEDM
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
@ -19,10 +40,12 @@ class BaseModel(torch.nn.Module):
|
||||
unet_config = model_config.unet_config
|
||||
self.latent_format = model_config.latent_format
|
||||
self.model_config = model_config
|
||||
self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device)
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||
if self.adm_channels is None:
|
||||
self.adm_channels = 0
|
||||
@ -30,38 +53,25 @@ class BaseModel(torch.nn.Module):
|
||||
print("model_type", model_type.name)
|
||||
print("adm", self.adm_channels)
|
||||
|
||||
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if given_betas is not None:
|
||||
betas = given_betas
|
||||
else:
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
|
||||
self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
||||
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
if c_concat is not None:
|
||||
xc = torch.cat([x] + [c_concat], dim=1)
|
||||
else:
|
||||
xc = x
|
||||
xc = torch.cat([xc] + [c_concat], dim=1)
|
||||
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype()
|
||||
xc = xc.to(dtype)
|
||||
t = t.to(dtype)
|
||||
t = self.model_sampling.timestep(t).float()
|
||||
context = context.to(dtype)
|
||||
if c_adm is not None:
|
||||
c_adm = c_adm.to(dtype)
|
||||
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra = kwargs[o]
|
||||
if hasattr(extra, "to"):
|
||||
extra = extra.to(dtype)
|
||||
extra_conds[o] = extra
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
@ -72,7 +82,8 @@ class BaseModel(torch.nn.Module):
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
def cond_concat(self, **kwargs):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
if self.inpaint_model:
|
||||
concat_keys = ("mask", "masked_image")
|
||||
cond_concat = []
|
||||
@ -101,8 +112,12 @@ class BaseModel(torch.nn.Module):
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(blank_inpaint_image_like(noise))
|
||||
return cond_concat
|
||||
return None
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = comfy.conds.CONDRegular(adm)
|
||||
return out
|
||||
|
||||
def load_model_weights(self, sd, unet_prefix=""):
|
||||
to_load = {}
|
||||
@ -111,6 +126,7 @@ class BaseModel(torch.nn.Module):
|
||||
if k.startswith(unet_prefix):
|
||||
to_load[k[len(unet_prefix):]] = sd.pop(k)
|
||||
|
||||
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||
if len(m) > 0:
|
||||
print("unet missing:", m)
|
||||
@ -147,6 +163,17 @@ class BaseModel(torch.nn.Module):
|
||||
def set_inpaint(self):
|
||||
self.inpaint_model = True
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
#TODO: this needs to be tweaked
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
|
||||
else:
|
||||
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
||||
|
||||
|
||||
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
|
||||
adm_inputs = []
|
||||
weights = []
|
||||
@ -241,3 +268,48 @@ class SDXL(BaseModel):
|
||||
out.append(self.embedder(torch.Tensor([target_width])))
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
|
||||
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||
|
||||
class SVD_img2vid(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.embedder = Timestep(256)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
fps_id = kwargs.get("fps", 6) - 1
|
||||
motion_bucket_id = kwargs.get("motion_bucket_id", 127)
|
||||
augmentation = kwargs.get("augmentation_level", 0)
|
||||
|
||||
out = []
|
||||
out.append(self.embedder(torch.Tensor([fps_id])))
|
||||
out.append(self.embedder(torch.Tensor([motion_bucket_id])))
|
||||
out.append(self.embedder(torch.Tensor([augmentation])))
|
||||
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
|
||||
return flat
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = comfy.conds.CONDRegular(adm)
|
||||
|
||||
latent_image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if latent_image is None:
|
||||
latent_image = torch.zeros_like(noise)
|
||||
|
||||
if latent_image.shape[1:] != noise.shape[1:]:
|
||||
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])
|
||||
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
|
||||
|
||||
if "time_conditioning" in kwargs:
|
||||
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
|
||||
|
||||
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
|
||||
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
|
||||
return out
|
||||
|
||||
@ -14,6 +14,20 @@ def count_blocks(state_dict_keys, prefix_string):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
context_dim = None
|
||||
use_linear_in_transformer = False
|
||||
|
||||
transformer_prefix = prefix + "1.transformer_blocks."
|
||||
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
|
||||
if len(transformer_keys) > 0:
|
||||
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
|
||||
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
|
||||
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
|
||||
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
|
||||
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
|
||||
return None
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
@ -40,72 +54,95 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
channel_mult = []
|
||||
attention_resolutions = []
|
||||
transformer_depth = []
|
||||
transformer_depth_output = []
|
||||
context_dim = None
|
||||
use_linear_in_transformer = False
|
||||
|
||||
video_model = False
|
||||
|
||||
current_res = 1
|
||||
count = 0
|
||||
|
||||
last_res_blocks = 0
|
||||
last_transformer_depth = 0
|
||||
last_channel_mult = 0
|
||||
|
||||
while True:
|
||||
input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.')
|
||||
for count in range(input_block_count):
|
||||
prefix = '{}input_blocks.{}.'.format(key_prefix, count)
|
||||
prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1)
|
||||
|
||||
block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
|
||||
if len(block_keys) == 0:
|
||||
break
|
||||
|
||||
block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys)))
|
||||
|
||||
if "{}0.op.weight".format(prefix) in block_keys: #new layer
|
||||
if last_transformer_depth > 0:
|
||||
attention_resolutions.append(current_res)
|
||||
transformer_depth.append(last_transformer_depth)
|
||||
num_res_blocks.append(last_res_blocks)
|
||||
channel_mult.append(last_channel_mult)
|
||||
|
||||
current_res *= 2
|
||||
last_res_blocks = 0
|
||||
last_transformer_depth = 0
|
||||
last_channel_mult = 0
|
||||
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
|
||||
if out is not None:
|
||||
transformer_depth_output.append(out[0])
|
||||
else:
|
||||
transformer_depth_output.append(0)
|
||||
else:
|
||||
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
|
||||
if res_block_prefix in block_keys:
|
||||
last_res_blocks += 1
|
||||
last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
|
||||
|
||||
transformer_prefix = prefix + "1.transformer_blocks."
|
||||
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
|
||||
if len(transformer_keys) > 0:
|
||||
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
|
||||
if context_dim is None:
|
||||
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
|
||||
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
|
||||
out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
|
||||
if out is not None:
|
||||
transformer_depth.append(out[0])
|
||||
if context_dim is None:
|
||||
context_dim = out[1]
|
||||
use_linear_in_transformer = out[2]
|
||||
video_model = out[3]
|
||||
else:
|
||||
transformer_depth.append(0)
|
||||
|
||||
res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
|
||||
if res_block_prefix in block_keys_output:
|
||||
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
|
||||
if out is not None:
|
||||
transformer_depth_output.append(out[0])
|
||||
else:
|
||||
transformer_depth_output.append(0)
|
||||
|
||||
count += 1
|
||||
|
||||
if last_transformer_depth > 0:
|
||||
attention_resolutions.append(current_res)
|
||||
transformer_depth.append(last_transformer_depth)
|
||||
num_res_blocks.append(last_res_blocks)
|
||||
channel_mult.append(last_channel_mult)
|
||||
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
|
||||
|
||||
if len(set(num_res_blocks)) == 1:
|
||||
num_res_blocks = num_res_blocks[0]
|
||||
|
||||
if len(set(transformer_depth)) == 1:
|
||||
transformer_depth = transformer_depth[0]
|
||||
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
|
||||
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
|
||||
else:
|
||||
transformer_depth_middle = -1
|
||||
|
||||
unet_config["in_channels"] = in_channels
|
||||
unet_config["model_channels"] = model_channels
|
||||
unet_config["num_res_blocks"] = num_res_blocks
|
||||
unet_config["attention_resolutions"] = attention_resolutions
|
||||
unet_config["transformer_depth"] = transformer_depth
|
||||
unet_config["transformer_depth_output"] = transformer_depth_output
|
||||
unet_config["channel_mult"] = channel_mult
|
||||
unet_config["transformer_depth_middle"] = transformer_depth_middle
|
||||
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
|
||||
unet_config["context_dim"] = context_dim
|
||||
|
||||
if video_model:
|
||||
unet_config["extra_ff_mix_layer"] = True
|
||||
unet_config["use_spatial_context"] = True
|
||||
unet_config["merge_strategy"] = "learned_with_images"
|
||||
unet_config["merge_factor"] = 0.0
|
||||
unet_config["video_kernel_size"] = [3, 1, 1]
|
||||
unet_config["use_temporal_resblock"] = True
|
||||
unet_config["use_temporal_attention"] = True
|
||||
else:
|
||||
unet_config["use_temporal_resblock"] = False
|
||||
unet_config["use_temporal_attention"] = False
|
||||
|
||||
return unet_config
|
||||
|
||||
def model_config_from_unet_config(unet_config):
|
||||
@ -124,19 +161,65 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma
|
||||
else:
|
||||
return model_config
|
||||
|
||||
def convert_config(unet_config):
|
||||
new_config = unet_config.copy()
|
||||
num_res_blocks = new_config.get("num_res_blocks", None)
|
||||
channel_mult = new_config.get("channel_mult", None)
|
||||
|
||||
if isinstance(num_res_blocks, int):
|
||||
num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
|
||||
if "attention_resolutions" in new_config:
|
||||
attention_resolutions = new_config.pop("attention_resolutions")
|
||||
transformer_depth = new_config.get("transformer_depth", None)
|
||||
transformer_depth_middle = new_config.get("transformer_depth_middle", None)
|
||||
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = len(channel_mult) * [transformer_depth]
|
||||
if transformer_depth_middle is None:
|
||||
transformer_depth_middle = transformer_depth[-1]
|
||||
t_in = []
|
||||
t_out = []
|
||||
s = 1
|
||||
for i in range(len(num_res_blocks)):
|
||||
res = num_res_blocks[i]
|
||||
d = 0
|
||||
if s in attention_resolutions:
|
||||
d = transformer_depth[i]
|
||||
|
||||
t_in += [d] * res
|
||||
t_out += [d] * (res + 1)
|
||||
s *= 2
|
||||
transformer_depth = t_in
|
||||
transformer_depth_output = t_out
|
||||
new_config["transformer_depth"] = t_in
|
||||
new_config["transformer_depth_output"] = t_out
|
||||
new_config["transformer_depth_middle"] = transformer_depth_middle
|
||||
|
||||
new_config["num_res_blocks"] = num_res_blocks
|
||||
return new_config
|
||||
|
||||
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
match = {}
|
||||
attention_resolutions = []
|
||||
transformer_depth = []
|
||||
|
||||
attn_res = 1
|
||||
for i in range(5):
|
||||
k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i)
|
||||
if k in state_dict:
|
||||
match["context_dim"] = state_dict[k].shape[1]
|
||||
attention_resolutions.append(attn_res)
|
||||
attn_res *= 2
|
||||
down_blocks = count_blocks(state_dict, "down_blocks.{}")
|
||||
for i in range(down_blocks):
|
||||
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
|
||||
for ab in range(attn_blocks):
|
||||
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
|
||||
transformer_depth.append(transformer_count)
|
||||
if transformer_count > 0:
|
||||
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
|
||||
|
||||
match["attention_resolutions"] = attention_resolutions
|
||||
attn_res *= 2
|
||||
if attn_blocks == 0:
|
||||
transformer_depth.append(0)
|
||||
transformer_depth.append(0)
|
||||
|
||||
match["transformer_depth"] = transformer_depth
|
||||
|
||||
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
|
||||
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
|
||||
@ -148,50 +231,65 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
|
||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
|
||||
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4,
|
||||
'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2],
|
||||
'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True,
|
||||
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
||||
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
|
||||
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
|
||||
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
|
||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
|
||||
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
|
||||
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
|
||||
'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint]
|
||||
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
|
||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
@ -200,7 +298,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
matches = False
|
||||
break
|
||||
if matches:
|
||||
return unet_config
|
||||
return convert_config(unet_config)
|
||||
return None
|
||||
|
||||
def model_config_from_diffusers_unet(state_dict, dtype):
|
||||
|
||||
@ -133,6 +133,10 @@ else:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
try:
|
||||
XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
XFORMERS_VERSION = xformers.version.__version__
|
||||
print("xformers version:", XFORMERS_VERSION)
|
||||
@ -478,6 +482,21 @@ def text_encoder_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def text_encoder_dtype(device=None):
|
||||
if args.fp8_e4m3fn_text_enc:
|
||||
return torch.float8_e4m3fn
|
||||
elif args.fp8_e5m2_text_enc:
|
||||
return torch.float8_e5m2
|
||||
elif args.fp16_text_enc:
|
||||
return torch.float16
|
||||
elif args.fp32_text_enc:
|
||||
return torch.float32
|
||||
|
||||
if should_use_fp16(device, prioritize_performance=False):
|
||||
return torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
def vae_device():
|
||||
return get_torch_device()
|
||||
|
||||
@ -579,27 +598,6 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
else:
|
||||
return mem_free_total
|
||||
|
||||
def batch_area_memory(area):
|
||||
if xformers_enabled() or pytorch_attention_flash_attention():
|
||||
#TODO: these formulas are copied from maximum_batch_area below
|
||||
return (area / 20) * (1024 * 1024)
|
||||
else:
|
||||
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
||||
|
||||
def maximum_batch_area():
|
||||
global vram_state
|
||||
if vram_state == VRAMState.NO_VRAM:
|
||||
return 0
|
||||
|
||||
memory_free = get_free_memory() / (1024 * 1024)
|
||||
if xformers_enabled() or pytorch_attention_flash_attention():
|
||||
#TODO: this needs to be tweaked
|
||||
area = 20 * memory_free
|
||||
else:
|
||||
#TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
|
||||
area = ((memory_free - 1024) * 0.9) / (0.6)
|
||||
return int(max(area, 0))
|
||||
|
||||
def cpu_mode():
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.CPU
|
||||
|
||||
@ -6,11 +6,13 @@ import comfy.utils
|
||||
import comfy.model_management
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.object_patches = {}
|
||||
self.object_patches_backup = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
@ -20,6 +22,8 @@ class ModelPatcher:
|
||||
else:
|
||||
self.current_device = current_device
|
||||
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
@ -33,11 +37,12 @@ class ModelPatcher:
|
||||
return size
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.model_keys = self.model_keys
|
||||
return n
|
||||
@ -47,6 +52,9 @@ class ModelPatcher:
|
||||
return True
|
||||
return False
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
return self.model.memory_required(input_shape=input_shape)
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||
@ -88,9 +96,18 @@ class ModelPatcher:
|
||||
def set_model_attn2_output_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_output_patch")
|
||||
|
||||
def set_model_input_block_patch(self, patch):
|
||||
self.set_model_patch(patch, "input_block_patch")
|
||||
|
||||
def set_model_input_block_patch_after_skip(self, patch):
|
||||
self.set_model_patch(patch, "input_block_patch_after_skip")
|
||||
|
||||
def set_model_output_block_patch(self, patch):
|
||||
self.set_model_patch(patch, "output_block_patch")
|
||||
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
def model_patches_to(self, device):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
@ -107,10 +124,10 @@ class ModelPatcher:
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], "to"):
|
||||
patch_list[k] = patch_list[k].to(device)
|
||||
if "unet_wrapper_function" in self.model_options:
|
||||
wrap_func = self.model_options["unet_wrapper_function"]
|
||||
if "model_function_wrapper" in self.model_options:
|
||||
wrap_func = self.model_options["model_function_wrapper"]
|
||||
if hasattr(wrap_func, "to"):
|
||||
self.model_options["unet_wrapper_function"] = wrap_func.to(device)
|
||||
self.model_options["model_function_wrapper"] = wrap_func.to(device)
|
||||
|
||||
def model_dtype(self):
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
@ -128,6 +145,7 @@ class ModelPatcher:
|
||||
return list(p)
|
||||
|
||||
def get_key_patches(self, filter_prefix=None):
|
||||
comfy.model_management.unload_model_clones(self)
|
||||
model_sd = self.model_state_dict()
|
||||
p = {}
|
||||
for k in model_sd:
|
||||
@ -150,6 +168,12 @@ class ModelPatcher:
|
||||
return sd
|
||||
|
||||
def patch_model(self, device_to=None):
|
||||
for k in self.object_patches:
|
||||
old = getattr(self.model, k)
|
||||
if k not in self.object_patches_backup:
|
||||
self.object_patches_backup[k] = old
|
||||
setattr(self.model, k, self.object_patches[k])
|
||||
|
||||
model_sd = self.model_state_dict()
|
||||
for key in self.patches:
|
||||
if key not in model_sd:
|
||||
@ -158,15 +182,20 @@ class ModelPatcher:
|
||||
|
||||
weight = model_sd[key]
|
||||
|
||||
inplace_update = self.weight_inplace_update
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(self.offload_device)
|
||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
||||
|
||||
if device_to is not None:
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
comfy.utils.set_attr(self.model, key, out_weight)
|
||||
if inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
comfy.utils.set_attr(self.model, key, out_weight)
|
||||
del temp_weight
|
||||
|
||||
if device_to is not None:
|
||||
@ -282,11 +311,21 @@ class ModelPatcher:
|
||||
def unpatch_model(self, device_to=None):
|
||||
keys = list(self.backup.keys())
|
||||
|
||||
for k in keys:
|
||||
comfy.utils.set_attr(self.model, k, self.backup[k])
|
||||
if self.weight_inplace_update:
|
||||
for k in keys:
|
||||
comfy.utils.copy_to_param(self.model, k, self.backup[k])
|
||||
else:
|
||||
for k in keys:
|
||||
comfy.utils.set_attr(self.model, k, self.backup[k])
|
||||
|
||||
self.backup = {}
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
keys = list(self.object_patches_backup.keys())
|
||||
for k in keys:
|
||||
setattr(self.model, k, self.object_patches_backup[k])
|
||||
|
||||
self.object_patches_backup = {}
|
||||
|
||||
129
comfy/model_sampling.py
Normal file
129
comfy/model_sampling.py
Normal file
@ -0,0 +1,129 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
import math
|
||||
|
||||
class EPS:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
|
||||
class V_PREDICTION(EPS):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__()
|
||||
beta_schedule = "linear"
|
||||
if model_config is not None:
|
||||
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
self.sigma_data = 1.0
|
||||
|
||||
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if given_betas is not None:
|
||||
betas = given_betas
|
||||
else:
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
|
||||
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
|
||||
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
||||
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def set_sigmas(self, sigmas):
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
|
||||
|
||||
def sigma(self, timestep):
|
||||
t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1))
|
||||
low_idx = t.floor().long()
|
||||
high_idx = t.ceil().long()
|
||||
w = t.frac()
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp().to(timestep.device)
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
return 999999999.9
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
percent = 1.0 - percent
|
||||
return self.sigma(torch.tensor(percent * 999.0)).item()
|
||||
|
||||
|
||||
class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__()
|
||||
self.sigma_data = 1.0
|
||||
|
||||
if model_config is not None:
|
||||
sampling_settings = model_config.sampling_settings
|
||||
else:
|
||||
sampling_settings = {}
|
||||
|
||||
sigma_min = sampling_settings.get("sigma_min", 0.002)
|
||||
sigma_max = sampling_settings.get("sigma_max", 120.0)
|
||||
self.set_sigma_range(sigma_min, sigma_max)
|
||||
|
||||
def set_sigma_range(self, sigma_min, sigma_max):
|
||||
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
|
||||
|
||||
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
return 0.25 * sigma.log()
|
||||
|
||||
def sigma(self, timestep):
|
||||
return (timestep / 0.25).exp()
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
return 999999999.9
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
percent = 1.0 - percent
|
||||
|
||||
log_sigma_min = math.log(self.sigma_min)
|
||||
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
|
||||
24
comfy/ops.py
24
comfy/ops.py
@ -1,29 +1,23 @@
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def forward(self, input):
|
||||
return torch.nn.functional.linear(input, self.weight, self.bias)
|
||||
class Linear(torch.nn.Linear):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
return Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return Conv3d(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
import comfy.conds
|
||||
import comfy.utils
|
||||
import math
|
||||
import numpy as np
|
||||
@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device):
|
||||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
|
||||
def broadcast_cond(cond, batch, device):
|
||||
"""broadcasts conditioning to the batch size"""
|
||||
copy = []
|
||||
for p in cond:
|
||||
t = comfy.utils.repeat_to_batch_size(p[0], batch)
|
||||
t = t.to(device)
|
||||
copy += [[t] + p[1:]]
|
||||
return copy
|
||||
|
||||
def get_models_from_cond(cond, model_type):
|
||||
models = []
|
||||
for c in cond:
|
||||
if model_type in c[1]:
|
||||
models += [c[1][model_type]]
|
||||
if model_type in c:
|
||||
models += [c[model_type]]
|
||||
return models
|
||||
|
||||
def convert_cond(cond):
|
||||
out = []
|
||||
for c in cond:
|
||||
temp = c[1].copy()
|
||||
model_conds = temp.get("model_conds", {})
|
||||
if c[0] is not None:
|
||||
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0])
|
||||
temp["model_conds"] = model_conds
|
||||
out.append(temp)
|
||||
return out
|
||||
|
||||
def get_additional_models(positive, negative, dtype):
|
||||
"""loads additional models in positive and negative conditioning"""
|
||||
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
|
||||
@ -72,18 +75,18 @@ def cleanup_additional_models(models):
|
||||
|
||||
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
||||
device = model.load_device
|
||||
positive = convert_cond(positive)
|
||||
negative = convert_cond(negative)
|
||||
|
||||
if noise_mask is not None:
|
||||
noise_mask = prepare_mask(noise_mask, noise_shape, device)
|
||||
|
||||
real_model = None
|
||||
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
|
||||
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
|
||||
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
|
||||
real_model = model.model
|
||||
|
||||
positive_copy = broadcast_cond(positive, noise_shape[0], device)
|
||||
negative_copy = broadcast_cond(negative, noise_shape[0], device)
|
||||
return real_model, positive_copy, negative_copy, noise_mask, models
|
||||
return real_model, positive, negative, noise_mask, models
|
||||
|
||||
|
||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||
|
||||
@ -1,48 +1,42 @@
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .k_diffusion import external as k_diffusion_external
|
||||
from .extra_samplers import uni_pc
|
||||
import torch
|
||||
import enum
|
||||
from comfy import model_management
|
||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||
import math
|
||||
from comfy import model_base
|
||||
import comfy.utils
|
||||
import comfy.conds
|
||||
|
||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||
return abs(a*b) // math.gcd(a, b)
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns predicted noise
|
||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def get_area_and_mult(cond, x_in, timestep_in):
|
||||
#Returns denoised
|
||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
strength = 1.0
|
||||
if 'timestep_start' in cond[1]:
|
||||
timestep_start = cond[1]['timestep_start']
|
||||
|
||||
if 'timestep_start' in conds:
|
||||
timestep_start = conds['timestep_start']
|
||||
if timestep_in[0] > timestep_start:
|
||||
return None
|
||||
if 'timestep_end' in cond[1]:
|
||||
timestep_end = cond[1]['timestep_end']
|
||||
if 'timestep_end' in conds:
|
||||
timestep_end = conds['timestep_end']
|
||||
if timestep_in[0] < timestep_end:
|
||||
return None
|
||||
if 'area' in cond[1]:
|
||||
area = cond[1]['area']
|
||||
if 'strength' in cond[1]:
|
||||
strength = cond[1]['strength']
|
||||
|
||||
adm_cond = None
|
||||
if 'adm_encoded' in cond[1]:
|
||||
adm_cond = cond[1]['adm_encoded']
|
||||
if 'area' in conds:
|
||||
area = conds['area']
|
||||
if 'strength' in conds:
|
||||
strength = conds['strength']
|
||||
|
||||
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
if 'mask' in cond[1]:
|
||||
if 'mask' in conds:
|
||||
# Scale the mask to the size of the input
|
||||
# The mask should have been resized as we began the sampling process
|
||||
mask_strength = 1.0
|
||||
if "mask_strength" in cond[1]:
|
||||
mask_strength = cond[1]["mask_strength"]
|
||||
mask = cond[1]['mask']
|
||||
if "mask_strength" in conds:
|
||||
mask_strength = conds["mask_strength"]
|
||||
mask = conds['mask']
|
||||
assert(mask.shape[1] == x_in.shape[2])
|
||||
assert(mask.shape[2] == x_in.shape[3])
|
||||
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
||||
@ -51,7 +45,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
mask = torch.ones_like(input_x)
|
||||
mult = mask * strength
|
||||
|
||||
if 'mask' not in cond[1]:
|
||||
if 'mask' not in conds:
|
||||
rr = 8
|
||||
if area[2] != 0:
|
||||
for t in range(rr):
|
||||
@ -67,27 +61,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||
|
||||
conditionning = {}
|
||||
conditionning['c_crossattn'] = cond[0]
|
||||
|
||||
if 'concat' in cond[1]:
|
||||
cond_concat_in = cond[1]['concat']
|
||||
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
||||
cropped = []
|
||||
for x in cond_concat_in:
|
||||
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
cropped.append(cr)
|
||||
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
||||
|
||||
if adm_cond is not None:
|
||||
conditionning['c_adm'] = adm_cond
|
||||
model_conds = conds["model_conds"]
|
||||
for c in model_conds:
|
||||
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
|
||||
control = None
|
||||
if 'control' in cond[1]:
|
||||
control = cond[1]['control']
|
||||
if 'control' in conds:
|
||||
control = conds['control']
|
||||
|
||||
patches = None
|
||||
if 'gligen' in cond[1]:
|
||||
gligen = cond[1]['gligen']
|
||||
if 'gligen' in conds:
|
||||
gligen = conds['gligen']
|
||||
patches = {}
|
||||
gligen_type = gligen[0]
|
||||
gligen_model = gligen[1]
|
||||
@ -105,22 +89,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
return True
|
||||
if c1.keys() != c2.keys():
|
||||
return False
|
||||
if 'c_crossattn' in c1:
|
||||
s1 = c1['c_crossattn'].shape
|
||||
s2 = c2['c_crossattn'].shape
|
||||
if s1 != s2:
|
||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||
return False
|
||||
|
||||
mult_min = lcm(s1[1], s2[1])
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
if 'c_concat' in c1:
|
||||
if c1['c_concat'].shape != c2['c_concat'].shape:
|
||||
return False
|
||||
if 'c_adm' in c1:
|
||||
if c1['c_adm'].shape != c2['c_adm'].shape:
|
||||
for k in c1:
|
||||
if not c1[k].can_concat(c2[k]):
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -149,39 +119,27 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
c_concat = []
|
||||
c_adm = []
|
||||
crossattn_max_len = 0
|
||||
for x in c_list:
|
||||
if 'c_crossattn' in x:
|
||||
c = x['c_crossattn']
|
||||
if crossattn_max_len == 0:
|
||||
crossattn_max_len = c.shape[1]
|
||||
else:
|
||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
||||
c_crossattn.append(c)
|
||||
if 'c_concat' in x:
|
||||
c_concat.append(x['c_concat'])
|
||||
if 'c_adm' in x:
|
||||
c_adm.append(x['c_adm'])
|
||||
out = {}
|
||||
c_crossattn_out = []
|
||||
for c in c_crossattn:
|
||||
if c.shape[1] < crossattn_max_len:
|
||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||
c_crossattn_out.append(c)
|
||||
|
||||
if len(c_crossattn_out) > 0:
|
||||
out['c_crossattn'] = torch.cat(c_crossattn_out)
|
||||
if len(c_concat) > 0:
|
||||
out['c_concat'] = torch.cat(c_concat)
|
||||
if len(c_adm) > 0:
|
||||
out['c_adm'] = torch.cat(c_adm)
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
cur = temp.get(k, [])
|
||||
cur.append(x[k])
|
||||
temp[k] = cur
|
||||
|
||||
out = {}
|
||||
for k in temp:
|
||||
conds = temp[k]
|
||||
out[k] = conds[0].concat(conds[1:])
|
||||
|
||||
return out
|
||||
|
||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in)/100000.0
|
||||
out_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
out_uncond = torch.zeros_like(x_in)
|
||||
out_uncond_count = torch.ones_like(x_in)/100000.0
|
||||
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
COND = 0
|
||||
UNCOND = 1
|
||||
@ -212,9 +170,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
if (len(batch_amount) * first_shape[0] * first_shape[2] * first_shape[3] < max_total_area):
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
@ -260,12 +220,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
del input_x
|
||||
|
||||
for o in range(batch_chunks):
|
||||
@ -281,39 +243,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
del out_count
|
||||
out_uncond /= out_uncond_count
|
||||
del out_uncond_count
|
||||
|
||||
return out_cond, out_uncond
|
||||
|
||||
|
||||
max_total_area = model_management.maximum_batch_area()
|
||||
if math.isclose(cond_scale, 1.0):
|
||||
uncond = None
|
||||
|
||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
|
||||
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
|
||||
return model_options["sampler_cfg_function"](args)
|
||||
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
||||
return x - model_options["sampler_cfg_function"](args)
|
||||
else:
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
|
||||
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_v(self, x, t, cond, **kwargs):
|
||||
return self.inner_model.apply_model(x, t, cond, **kwargs)
|
||||
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.alphas_cumprod = model.alphas_cumprod
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||
return out
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
class KSamplerX0Inpaint(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
@ -332,32 +283,40 @@ class KSamplerX0Inpaint(torch.nn.Module):
|
||||
return out
|
||||
|
||||
def simple_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
sigs = []
|
||||
ss = len(model.sigmas) / steps
|
||||
ss = len(s.sigmas) / steps
|
||||
for x in range(steps):
|
||||
sigs += [float(model.sigmas[-(1 + int(x * ss))])]
|
||||
sigs += [float(s.sigmas[-(1 + int(x * ss))])]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def ddim_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
sigs = []
|
||||
ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
|
||||
for x in range(len(ddim_timesteps) - 1, -1, -1):
|
||||
ts = ddim_timesteps[x]
|
||||
if ts > 999:
|
||||
ts = 999
|
||||
sigs.append(model.t_to_sigma(torch.tensor(ts)))
|
||||
ss = len(s.sigmas) // steps
|
||||
x = 1
|
||||
while x < len(s.sigmas):
|
||||
sigs += [float(s.sigmas[x])]
|
||||
x += ss
|
||||
sigs = sigs[::-1]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def sgm_scheduler(model, steps):
|
||||
def normal_scheduler(model, steps, sgm=False, floor=False):
|
||||
s = model.model_sampling
|
||||
start = s.timestep(s.sigma_max)
|
||||
end = s.timestep(s.sigma_min)
|
||||
|
||||
if sgm:
|
||||
timesteps = torch.linspace(start, end, steps + 1)[:-1]
|
||||
else:
|
||||
timesteps = torch.linspace(start, end, steps)
|
||||
|
||||
sigs = []
|
||||
timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
|
||||
for x in range(len(timesteps)):
|
||||
ts = timesteps[x]
|
||||
if ts > 999:
|
||||
ts = 999
|
||||
sigs.append(model.t_to_sigma(torch.tensor(ts)))
|
||||
sigs.append(s.sigma(ts))
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
@ -389,19 +348,19 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
||||
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
|
||||
for i in range(len(conditions)):
|
||||
c = conditions[i]
|
||||
if 'area' in c[1]:
|
||||
area = c[1]['area']
|
||||
if 'area' in c:
|
||||
area = c['area']
|
||||
if area[0] == "percentage":
|
||||
modified = c[1].copy()
|
||||
modified = c.copy()
|
||||
area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
|
||||
modified['area'] = area
|
||||
c = [c[0], modified]
|
||||
c = modified
|
||||
conditions[i] = c
|
||||
|
||||
if 'mask' in c[1]:
|
||||
mask = c[1]['mask']
|
||||
if 'mask' in c:
|
||||
mask = c['mask']
|
||||
mask = mask.to(device=device)
|
||||
modified = c[1].copy()
|
||||
modified = c.copy()
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1] != h or mask.shape[2] != w:
|
||||
@ -422,66 +381,70 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
||||
modified['area'] = area
|
||||
|
||||
modified['mask'] = mask
|
||||
conditions[i] = [c[0], modified]
|
||||
conditions[i] = modified
|
||||
|
||||
def create_cond_with_same_area_if_none(conds, c):
|
||||
if 'area' not in c[1]:
|
||||
if 'area' not in c:
|
||||
return
|
||||
|
||||
c_area = c[1]['area']
|
||||
c_area = c['area']
|
||||
smallest = None
|
||||
for x in conds:
|
||||
if 'area' in x[1]:
|
||||
a = x[1]['area']
|
||||
if 'area' in x:
|
||||
a = x['area']
|
||||
if c_area[2] >= a[2] and c_area[3] >= a[3]:
|
||||
if a[0] + a[2] >= c_area[0] + c_area[2]:
|
||||
if a[1] + a[3] >= c_area[1] + c_area[3]:
|
||||
if smallest is None:
|
||||
smallest = x
|
||||
elif 'area' not in smallest[1]:
|
||||
elif 'area' not in smallest:
|
||||
smallest = x
|
||||
else:
|
||||
if smallest[1]['area'][0] * smallest[1]['area'][1] > a[0] * a[1]:
|
||||
if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
|
||||
smallest = x
|
||||
else:
|
||||
if smallest is None:
|
||||
smallest = x
|
||||
if smallest is None:
|
||||
return
|
||||
if 'area' in smallest[1]:
|
||||
if smallest[1]['area'] == c_area:
|
||||
if 'area' in smallest:
|
||||
if smallest['area'] == c_area:
|
||||
return
|
||||
n = c[1].copy()
|
||||
conds += [[smallest[0], n]]
|
||||
|
||||
out = c.copy()
|
||||
out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied?
|
||||
conds += [out]
|
||||
|
||||
def calculate_start_end_timesteps(model, conds):
|
||||
s = model.model_sampling
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
|
||||
timestep_start = None
|
||||
timestep_end = None
|
||||
if 'start_percent' in x[1]:
|
||||
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
|
||||
if 'end_percent' in x[1]:
|
||||
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
|
||||
if 'start_percent' in x:
|
||||
timestep_start = s.percent_to_sigma(x['start_percent'])
|
||||
if 'end_percent' in x:
|
||||
timestep_end = s.percent_to_sigma(x['end_percent'])
|
||||
|
||||
if (timestep_start is not None) or (timestep_end is not None):
|
||||
n = x[1].copy()
|
||||
n = x.copy()
|
||||
if (timestep_start is not None):
|
||||
n['timestep_start'] = timestep_start
|
||||
if (timestep_end is not None):
|
||||
n['timestep_end'] = timestep_end
|
||||
conds[t] = [x[0], n]
|
||||
conds[t] = n
|
||||
|
||||
def pre_run_control(model, conds):
|
||||
s = model.model_sampling
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
|
||||
timestep_start = None
|
||||
timestep_end = None
|
||||
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
|
||||
if 'control' in x[1]:
|
||||
x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
|
||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||
if 'control' in x:
|
||||
x['control'].pre_run(model, percent_to_timestep_function)
|
||||
|
||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
cond_cnets = []
|
||||
@ -490,16 +453,16 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
uncond_other = []
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
if 'area' not in x[1]:
|
||||
if name in x[1] and x[1][name] is not None:
|
||||
cond_cnets.append(x[1][name])
|
||||
if 'area' not in x:
|
||||
if name in x and x[name] is not None:
|
||||
cond_cnets.append(x[name])
|
||||
else:
|
||||
cond_other.append((x, t))
|
||||
for t in range(len(uncond)):
|
||||
x = uncond[t]
|
||||
if 'area' not in x[1]:
|
||||
if name in x[1] and x[1][name] is not None:
|
||||
uncond_cnets.append(x[1][name])
|
||||
if 'area' not in x:
|
||||
if name in x and x[name] is not None:
|
||||
uncond_cnets.append(x[name])
|
||||
else:
|
||||
uncond_other.append((x, t))
|
||||
|
||||
@ -509,47 +472,35 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
for x in range(len(cond_cnets)):
|
||||
temp = uncond_other[x % len(uncond_other)]
|
||||
o = temp[0]
|
||||
if name in o[1] and o[1][name] is not None:
|
||||
n = o[1].copy()
|
||||
if name in o and o[name] is not None:
|
||||
n = o.copy()
|
||||
n[name] = uncond_fill_func(cond_cnets, x)
|
||||
uncond += [[o[0], n]]
|
||||
uncond += [n]
|
||||
else:
|
||||
n = o[1].copy()
|
||||
n = o.copy()
|
||||
n[name] = uncond_fill_func(cond_cnets, x)
|
||||
uncond[temp[1]] = [o[0], n]
|
||||
uncond[temp[1]] = n
|
||||
|
||||
def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
|
||||
def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs):
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
adm_out = None
|
||||
if 'adm' in x[1]:
|
||||
adm_out = x[1]["adm"]
|
||||
else:
|
||||
params = x[1].copy()
|
||||
params["width"] = params.get("width", width * 8)
|
||||
params["height"] = params.get("height", height * 8)
|
||||
params["prompt_type"] = params.get("prompt_type", prompt_type)
|
||||
adm_out = model.encode_adm(device=device, **params)
|
||||
|
||||
if adm_out is not None:
|
||||
x[1] = x[1].copy()
|
||||
x[1]["adm_encoded"] = comfy.utils.repeat_to_batch_size(adm_out, batch_size).to(device)
|
||||
|
||||
return conds
|
||||
|
||||
def encode_cond(model_function, key, conds, device, **kwargs):
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
params = x[1].copy()
|
||||
params = x.copy()
|
||||
params["device"] = device
|
||||
params["noise"] = noise
|
||||
params["width"] = params.get("width", noise.shape[3] * 8)
|
||||
params["height"] = params.get("height", noise.shape[2] * 8)
|
||||
params["prompt_type"] = params.get("prompt_type", prompt_type)
|
||||
for k in kwargs:
|
||||
if k not in params:
|
||||
params[k] = kwargs[k]
|
||||
|
||||
out = model_function(**params)
|
||||
if out is not None:
|
||||
x[1] = x[1].copy()
|
||||
x[1][key] = out
|
||||
x = x.copy()
|
||||
model_conds = x['model_conds'].copy()
|
||||
for k in out:
|
||||
model_conds[k] = out[k]
|
||||
x['model_conds'] = model_conds
|
||||
conds[t] = x
|
||||
return conds
|
||||
|
||||
class Sampler:
|
||||
@ -557,95 +508,79 @@ class Sampler:
|
||||
pass
|
||||
|
||||
def max_denoise(self, model_wrap, sigmas):
|
||||
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05)
|
||||
|
||||
class DDIM(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
timesteps = []
|
||||
for s in range(sigmas.shape[0]):
|
||||
timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s]))
|
||||
noise_mask = None
|
||||
if denoise_mask is not None:
|
||||
noise_mask = 1.0 - denoise_mask
|
||||
|
||||
ddim_callback = None
|
||||
if callback is not None:
|
||||
total_steps = len(timesteps) - 1
|
||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
||||
|
||||
max_denoise = self.max_denoise(model_wrap, sigmas)
|
||||
|
||||
ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
|
||||
ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||
z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise)
|
||||
samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
|
||||
batch_size=noise.shape[0],
|
||||
shape=noise.shape[1:],
|
||||
verbose=False,
|
||||
eta=0.0,
|
||||
x_T=z_enc,
|
||||
x0=latent_image,
|
||||
img_callback=ddim_callback,
|
||||
denoise_function=model_wrap.predict_eps_discrete_timestep,
|
||||
extra_args=extra_args,
|
||||
mask=noise_mask,
|
||||
to_zero=sigmas[-1]==0,
|
||||
end_step=sigmas.shape[0] - 1,
|
||||
disable_pbar=disable_pbar)
|
||||
return samples
|
||||
max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
|
||||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
class UNIPC(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||
|
||||
class UNIPCBH2(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
||||
|
||||
def ksampler(sampler_name, extra_options={}):
|
||||
class KSAMPLER(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
model_k = KSamplerX0Inpaint(model_wrap)
|
||||
model_k.latent_image = latent_image
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
self.sampler_function = sampler_function
|
||||
self.extra_options = extra_options
|
||||
self.inpaint_options = inpaint_options
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
model_k = KSamplerX0Inpaint(model_wrap)
|
||||
model_k.latent_image = latent_image
|
||||
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
|
||||
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
|
||||
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
|
||||
else:
|
||||
model_k.noise = noise
|
||||
|
||||
if self.max_denoise(model_wrap, sigmas):
|
||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
else:
|
||||
noise = noise * sigmas[0]
|
||||
if self.max_denoise(model_wrap, sigmas):
|
||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
else:
|
||||
noise = noise * sigmas[0]
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
|
||||
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
|
||||
return samples
|
||||
|
||||
|
||||
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||
if sampler_name == "dpm_fast":
|
||||
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
|
||||
sigma_min = sigmas[-1]
|
||||
if sigma_min == 0:
|
||||
sigma_min = sigmas[-2]
|
||||
total_steps = len(sigmas) - 1
|
||||
return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
|
||||
sampler_function = dpm_fast_function
|
||||
elif sampler_name == "dpm_adaptive":
|
||||
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
|
||||
sigma_min = sigmas[-1]
|
||||
if sigma_min == 0:
|
||||
sigma_min = sigmas[-2]
|
||||
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
|
||||
sampler_function = dpm_adaptive_function
|
||||
else:
|
||||
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
|
||||
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
if sampler_name == "dpm_fast":
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
elif sampler_name == "dpm_adaptive":
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
else:
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options)
|
||||
return samples
|
||||
return KSAMPLER
|
||||
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
||||
|
||||
def wrap_model(model):
|
||||
model_denoise = CFGNoisePredictor(model)
|
||||
if model.model_type == model_base.ModelType.V_PREDICTION:
|
||||
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
|
||||
else:
|
||||
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
|
||||
return model_wrap
|
||||
return model_denoise
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
positive = positive[:]
|
||||
@ -656,8 +591,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
|
||||
model_wrap = wrap_model(model)
|
||||
|
||||
calculate_start_end_timesteps(model_wrap, negative)
|
||||
calculate_start_end_timesteps(model_wrap, positive)
|
||||
calculate_start_end_timesteps(model, negative)
|
||||
calculate_start_end_timesteps(model, positive)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for c in positive:
|
||||
@ -665,21 +600,17 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
for c in negative:
|
||||
create_cond_with_same_area_if_none(positive, c)
|
||||
|
||||
pre_run_control(model_wrap, negative + positive)
|
||||
pre_run_control(model, negative + positive)
|
||||
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
if latent_image is not None:
|
||||
latent_image = model.process_latent_in(latent_image)
|
||||
|
||||
if model.is_adm():
|
||||
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
|
||||
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
||||
|
||||
if hasattr(model, 'cond_concat'):
|
||||
positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
if hasattr(model, 'extra_conds'):
|
||||
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||
|
||||
@ -690,30 +621,29 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
||||
model_wrap = wrap_model(model)
|
||||
if scheduler_name == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
elif scheduler_name == "exponential":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
elif scheduler_name == "normal":
|
||||
sigmas = model_wrap.get_sigmas(steps)
|
||||
sigmas = normal_scheduler(model, steps)
|
||||
elif scheduler_name == "simple":
|
||||
sigmas = simple_scheduler(model_wrap, steps)
|
||||
sigmas = simple_scheduler(model, steps)
|
||||
elif scheduler_name == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(model_wrap, steps)
|
||||
sigmas = ddim_scheduler(model, steps)
|
||||
elif scheduler_name == "sgm_uniform":
|
||||
sigmas = sgm_scheduler(model_wrap, steps)
|
||||
sigmas = normal_scheduler(model, steps, sgm=True)
|
||||
else:
|
||||
print("error invalid scheduler", self.scheduler)
|
||||
return sigmas
|
||||
|
||||
def sampler_class(name):
|
||||
def sampler_object(name):
|
||||
if name == "uni_pc":
|
||||
sampler = UNIPC
|
||||
sampler = UNIPC()
|
||||
elif name == "uni_pc_bh2":
|
||||
sampler = UNIPCBH2
|
||||
sampler = UNIPCBH2()
|
||||
elif name == "ddim":
|
||||
sampler = DDIM
|
||||
sampler = ksampler("euler", inpaint_options={"random": True})
|
||||
else:
|
||||
sampler = ksampler(name)
|
||||
return sampler
|
||||
@ -776,6 +706,6 @@ class KSampler:
|
||||
else:
|
||||
return torch.zeros_like(noise)
|
||||
|
||||
sampler = sampler_class(self.sampler)
|
||||
sampler = sampler_object(self.sampler)
|
||||
|
||||
return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
|
||||
99
comfy/sd.py
99
comfy/sd.py
@ -23,6 +23,7 @@ import comfy.model_patcher
|
||||
import comfy.lora
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.supported_models_base
|
||||
import comfy.taesd.taesd
|
||||
|
||||
def load_model_weights(model, sd):
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
@ -55,13 +56,26 @@ def load_clip_weights(model, sd):
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model)
|
||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
if clip is not None:
|
||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
|
||||
loaded = comfy.lora.load_lora(lora, key_map)
|
||||
new_modelpatcher = model.clone()
|
||||
k = new_modelpatcher.add_patches(loaded, strength_model)
|
||||
new_clip = clip.clone()
|
||||
k1 = new_clip.add_patches(loaded, strength_clip)
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
k = new_modelpatcher.add_patches(loaded, strength_model)
|
||||
else:
|
||||
k = ()
|
||||
new_modelpatcher = None
|
||||
|
||||
if clip is not None:
|
||||
new_clip = clip.clone()
|
||||
k1 = new_clip.add_patches(loaded, strength_clip)
|
||||
else:
|
||||
k1 = ()
|
||||
new_clip = None
|
||||
k = set(k)
|
||||
k1 = set(k1)
|
||||
for x in loaded:
|
||||
@ -82,10 +96,7 @@ class CLIP:
|
||||
load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
params['device'] = offload_device
|
||||
if model_management.should_use_fp16(load_device, prioritize_performance=False):
|
||||
params['dtype'] = torch.float16
|
||||
else:
|
||||
params['dtype'] = torch.float32
|
||||
params['dtype'] = model_management.text_encoder_dtype(load_device)
|
||||
|
||||
self.cond_stage_model = clip(**(params))
|
||||
|
||||
@ -144,10 +155,24 @@ class VAE:
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
||||
|
||||
if config is None:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
decoder_config = encoder_config.copy()
|
||||
decoder_config["video_kernel_size"] = [3, 1, 1]
|
||||
decoder_config["alpha"] = 0.0
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||
elif "taesd_decoder.1.weight" in sd:
|
||||
self.first_stage_model = comfy.taesd.taesd.TAESD()
|
||||
else:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||
else:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
@ -162,10 +187,12 @@ class VAE:
|
||||
if device is None:
|
||||
device = model_management.vae_device()
|
||||
self.device = device
|
||||
self.offload_device = model_management.vae_offload_device()
|
||||
offload_device = model_management.vae_offload_device()
|
||||
self.vae_dtype = model_management.vae_dtype()
|
||||
self.first_stage_model.to(self.vae_dtype)
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
@ -194,10 +221,9 @@ class VAE:
|
||||
return samples
|
||||
|
||||
def decode(self, samples_in):
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
try:
|
||||
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
|
||||
model_management.free_memory(memory_used, self.device)
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
@ -210,22 +236,19 @@ class VAE:
|
||||
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
|
||||
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||
pixel_samples = pixel_samples.cpu().movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
||||
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||
return output.movedim(1,-1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
try:
|
||||
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
|
||||
model_management.free_memory(memory_used, self.device)
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
@ -238,14 +261,12 @@ class VAE:
|
||||
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
||||
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||
return samples
|
||||
|
||||
def get_sd(self):
|
||||
@ -360,7 +381,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
from . import latent_formats
|
||||
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
|
||||
model_config.unet_config = unet_config
|
||||
model_config.unet_config = model_detection.convert_config(unet_config)
|
||||
|
||||
if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
||||
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
|
||||
@ -388,11 +409,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
|
||||
clip_target.clip = sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model.clip_h
|
||||
elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model.clip_l
|
||||
load_clip_weights(w, state_dict)
|
||||
|
||||
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
||||
@ -429,6 +452,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
|
||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||
vae = VAE(sd=vae_sd)
|
||||
|
||||
if output_clip:
|
||||
@ -453,20 +477,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_unet(unet_path): #load unet in diffusers format
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
if "input_blocks.0.0.weight" in sd: #ldm
|
||||
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return None
|
||||
new_sd = sd
|
||||
|
||||
else: #diffusers
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
|
||||
if model_config is None:
|
||||
print("ERROR UNSUPPORTED UNET", unet_path)
|
||||
return None
|
||||
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
|
||||
@ -481,8 +503,19 @@ def load_unet(unet_path): #load unet in diffusers format
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
print("left over keys in unet:", left_over)
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
||||
|
||||
def load_unet(unet_path):
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_unet_state_dict(sd)
|
||||
if model is None:
|
||||
print("ERROR UNSUPPORTED UNET", unet_path)
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
||||
model_management.load_models_gpu([model, clip.load_model()])
|
||||
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
|
||||
|
||||
@ -8,34 +8,56 @@ import zipfile
|
||||
from . import model_management
|
||||
import contextlib
|
||||
|
||||
def gen_empty_tokens(special_tokens, length):
|
||||
start_token = special_tokens.get("start", None)
|
||||
end_token = special_tokens.get("end", None)
|
||||
pad_token = special_tokens.get("pad")
|
||||
output = []
|
||||
if start_token is not None:
|
||||
output.append(start_token)
|
||||
if end_token is not None:
|
||||
output.append(end_token)
|
||||
output += [pad_token] * (length - len(output))
|
||||
return output
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
to_encode = list(self.empty_tokens)
|
||||
to_encode = list()
|
||||
max_token_len = 0
|
||||
has_weights = False
|
||||
for x in token_weight_pairs:
|
||||
tokens = list(map(lambda a: a[0], x))
|
||||
max_token_len = max(len(tokens), max_token_len)
|
||||
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
|
||||
to_encode.append(tokens)
|
||||
|
||||
sections = len(to_encode)
|
||||
if has_weights or sections == 0:
|
||||
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
|
||||
|
||||
out, pooled = self.encode(to_encode)
|
||||
z_empty = out[0:1]
|
||||
if pooled.shape[0] > 1:
|
||||
first_pooled = pooled[1:2]
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].cpu()
|
||||
else:
|
||||
first_pooled = pooled[0:1]
|
||||
first_pooled = pooled
|
||||
|
||||
output = []
|
||||
for k in range(1, out.shape[0]):
|
||||
for k in range(0, sections):
|
||||
z = out[k:k+1]
|
||||
for i in range(len(z)):
|
||||
for j in range(len(z[i])):
|
||||
weight = token_weight_pairs[k - 1][j][1]
|
||||
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
|
||||
if has_weights:
|
||||
z_empty = out[-1]
|
||||
for i in range(len(z)):
|
||||
for j in range(len(z[i])):
|
||||
weight = token_weight_pairs[k][j][1]
|
||||
if weight != 1.0:
|
||||
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
return z_empty.cpu(), first_pooled.cpu()
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
|
||||
return out[-1:].cpu(), first_pooled
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled
|
||||
|
||||
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
@ -43,37 +65,43 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"hidden"
|
||||
]
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
|
||||
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.num_layers = 12
|
||||
if textmodel_path is not None:
|
||||
self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
|
||||
self.transformer = model_class.from_pretrained(textmodel_path)
|
||||
else:
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||
config = CLIPTextConfig.from_json_file(textmodel_json_config)
|
||||
config = config_class.from_json_file(textmodel_json_config)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
with comfy.ops.use_comfy_ops(device, dtype):
|
||||
with modeling_utils.no_init_weights():
|
||||
self.transformer = CLIPTextModel(config)
|
||||
self.transformer = model_class(config)
|
||||
|
||||
self.inner_name = inner_name
|
||||
if dtype is not None:
|
||||
self.transformer.to(dtype)
|
||||
self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
|
||||
self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
|
||||
inner_model = getattr(self.transformer, self.inner_name)
|
||||
if hasattr(inner_model, "embeddings"):
|
||||
inner_model.embeddings.to(torch.float32)
|
||||
else:
|
||||
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
|
||||
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
self.layer_idx = None
|
||||
self.empty_tokens = [[49406] + [49407] * 76]
|
||||
self.special_tokens = special_tokens
|
||||
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
|
||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||
self.enable_attention_masks = False
|
||||
|
||||
self.layer_norm_hidden_state = True
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) <= self.num_layers
|
||||
@ -117,7 +145,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
else:
|
||||
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
|
||||
while len(tokens_temp) < len(x):
|
||||
tokens_temp += [self.empty_tokens[0][-1]]
|
||||
tokens_temp += [self.special_tokens["pad"]]
|
||||
out_tokens += [tokens_temp]
|
||||
|
||||
n = token_dict_size
|
||||
@ -142,12 +170,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
||||
tokens = torch.LongTensor(tokens).to(device)
|
||||
|
||||
if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
|
||||
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = lambda a, b: contextlib.nullcontext(a)
|
||||
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
|
||||
|
||||
with precision_scope(model_management.get_autocast_device(device), torch.float32):
|
||||
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks:
|
||||
attention_mask = torch.zeros_like(tokens)
|
||||
@ -168,12 +196,16 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
else:
|
||||
z = outputs.hidden_states[self.layer_idx]
|
||||
if self.layer_norm_hidden_state:
|
||||
z = self.transformer.text_model.final_layer_norm(z)
|
||||
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
|
||||
|
||||
pooled_output = outputs.pooler_output
|
||||
if self.text_projection is not None:
|
||||
if hasattr(outputs, "pooler_output"):
|
||||
pooled_output = outputs.pooler_output.float()
|
||||
else:
|
||||
pooled_output = None
|
||||
|
||||
if self.text_projection is not None and pooled_output is not None:
|
||||
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
|
||||
return z.float(), pooled_output.float()
|
||||
return z.float(), pooled_output
|
||||
|
||||
def encode(self, tokens):
|
||||
return self(tokens)
|
||||
@ -278,7 +310,13 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
|
||||
valid_file = None
|
||||
for embed_dir in embedding_directory:
|
||||
embed_path = os.path.join(embed_dir, embedding_name)
|
||||
embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
|
||||
embed_dir = os.path.abspath(embed_dir)
|
||||
try:
|
||||
if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
|
||||
continue
|
||||
except:
|
||||
continue
|
||||
if not os.path.isfile(embed_path):
|
||||
extensions = ['.safetensors', '.pt', '.bin']
|
||||
for x in extensions:
|
||||
@ -336,18 +374,25 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
embed_out = next(iter(values))
|
||||
return embed_out
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
|
||||
self.max_length = max_length
|
||||
self.max_tokens_per_section = self.max_length - 2
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
if has_start_token:
|
||||
self.tokens_start = 1
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
else:
|
||||
self.tokens_start = 0
|
||||
self.start_token = None
|
||||
self.end_token = empty[0]
|
||||
self.pad_with_end = pad_with_end
|
||||
self.pad_to_max_length = pad_to_max_length
|
||||
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.embedding_directory = embedding_directory
|
||||
@ -408,11 +453,13 @@ class SD1Tokenizer:
|
||||
else:
|
||||
continue
|
||||
#parse word
|
||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
|
||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
|
||||
|
||||
#reshape token array to CLIP input size
|
||||
batched_tokens = []
|
||||
batch = [(self.start_token, 1.0, 0)]
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
for i, t_group in enumerate(tokens):
|
||||
#determine if we're going to try and keep the tokens in a single batch
|
||||
@ -429,16 +476,21 @@ class SD1Tokenizer:
|
||||
#add end token and pad
|
||||
else:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||
#start new batch
|
||||
batch = [(self.start_token, 1.0, 0)]
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
else:
|
||||
batch.extend([(t,w,i+1) for t,w in t_group])
|
||||
t_group = []
|
||||
|
||||
#fill last batch
|
||||
batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
||||
@ -448,3 +500,40 @@ class SD1Tokenizer:
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
out = {}
|
||||
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return getattr(self, self.clip).untokenize(token_weight_pair)
|
||||
|
||||
|
||||
class SD1ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs):
|
||||
super().__init__()
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
getattr(self, self.clip).clip_layer(layer_idx)
|
||||
|
||||
def reset_clip_layer(self):
|
||||
getattr(self, self.clip).reset_clip_layer()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs = token_weight_pairs[self.clip_name]
|
||||
out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
|
||||
return out, pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
return getattr(self, self.clip).load_sd(sd)
|
||||
|
||||
@ -2,16 +2,23 @@ from comfy import sd1_clip
|
||||
import torch
|
||||
import os
|
||||
|
||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
layer_idx=23
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
|
||||
self.empty_tokens = [[49406] + [49407] + [0] * 75]
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||
|
||||
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
|
||||
|
||||
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||
|
||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
||||
|
||||
@ -2,28 +2,27 @@ from comfy import sd1_clip
|
||||
import torch
|
||||
import os
|
||||
|
||||
class SDXLClipG(sd1_clip.SD1ClipModel):
|
||||
class SDXLClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
layer_idx=-2
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
|
||||
self.empty_tokens = [[49406] + [49407] + [0] * 75]
|
||||
self.layer_norm_hidden_state = False
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
|
||||
class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
||||
|
||||
|
||||
class SDXLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class SDXLTokenizer:
|
||||
def __init__(self, embedding_directory=None):
|
||||
self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
@ -38,8 +37,7 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
|
||||
self.clip_l.layer_norm_hidden_state = False
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
@ -63,21 +61,6 @@ class SDXLClipModel(torch.nn.Module):
|
||||
else:
|
||||
return self.clip_l.load_sd(sd)
|
||||
|
||||
class SDXLRefinerClipModel(torch.nn.Module):
|
||||
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
self.clip_g.clip_layer(layer_idx)
|
||||
|
||||
def reset_clip_layer(self):
|
||||
self.clip_g.reset_clip_layer()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_g = token_weight_pairs["g"]
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
return g_out, g_pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.clip_g.load_sd(sd)
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
|
||||
|
||||
@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": False,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
@ -38,8 +39,15 @@ class SD15(supported_models_base.BASE):
|
||||
if ids.dtype == torch.float32:
|
||||
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
||||
|
||||
replace_prefix = {}
|
||||
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"clip_l.": "cond_stage_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def clip_target(self):
|
||||
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
|
||||
|
||||
@ -49,6 +57,7 @@ class SD20(supported_models_base.BASE):
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
@ -62,12 +71,12 @@ class SD20(supported_models_base.BASE):
|
||||
return model_base.ModelType.EPS
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
|
||||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix[""] = "cond_stage_model.model."
|
||||
replace_prefix["clip_h"] = "cond_stage_model.model"
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
||||
return state_dict
|
||||
@ -81,6 +90,7 @@ class SD21UnclipL(SD20):
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"adm_in_channels": 1536,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "embedder.model.visual."
|
||||
@ -93,6 +103,7 @@ class SD21UnclipH(SD20):
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"adm_in_channels": 2048,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "embedder.model.visual."
|
||||
@ -104,7 +115,8 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
"use_linear_in_transformer": True,
|
||||
"context_dim": 1280,
|
||||
"adm_in_channels": 2560,
|
||||
"transformer_depth": [0, 4, 4, 0],
|
||||
"transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
@ -139,9 +151,10 @@ class SDXL(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 2, 10],
|
||||
"transformer_depth": [0, 0, 2, 2, 10, 10],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
@ -165,6 +178,7 @@ class SDXL(supported_models_base.BASE):
|
||||
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
|
||||
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
|
||||
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
|
||||
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
@ -189,5 +203,40 @@ class SDXL(supported_models_base.BASE):
|
||||
def clip_target(self):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
|
||||
|
||||
class SSD1B(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 0, 2, 2, 4, 4],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL]
|
||||
class SVD_img2vid(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"in_channels": 8,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
"context_dim": 1024,
|
||||
"adm_in_channels": 768,
|
||||
"use_temporal_attention": True,
|
||||
"use_temporal_resblock": True
|
||||
}
|
||||
|
||||
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
|
||||
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SVD_img2vid(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self):
|
||||
return None
|
||||
|
||||
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -19,7 +19,7 @@ class BASE:
|
||||
clip_prefix = []
|
||||
clip_vision_prefix = None
|
||||
noise_aug_config = None
|
||||
beta_schedule = "linear"
|
||||
sampling_settings = {}
|
||||
latent_format = latent_formats.LatentFormat
|
||||
|
||||
@classmethod
|
||||
@ -53,6 +53,12 @@ class BASE:
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def process_vae_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "cond_stage_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
@ -46,15 +46,16 @@ class TAESD(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
|
||||
def __init__(self, encoder_path=None, decoder_path=None):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder()
|
||||
self.taesd_encoder = Encoder()
|
||||
self.taesd_decoder = Decoder()
|
||||
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
||||
if encoder_path is not None:
|
||||
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
||||
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
||||
if decoder_path is not None:
|
||||
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
||||
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
||||
|
||||
@staticmethod
|
||||
def scale_latents(x):
|
||||
@ -65,3 +66,11 @@ class TAESD(nn.Module):
|
||||
def unscale_latents(x):
|
||||
"""[0, 1] -> raw latents"""
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
def decode(self, x):
|
||||
x_sample = self.taesd_decoder(x * self.vae_scale)
|
||||
x_sample = x_sample.sub(0.5).mul(2)
|
||||
return x_sample
|
||||
|
||||
def encode(self, x):
|
||||
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
|
||||
|
||||
@ -170,25 +170,12 @@ UNET_MAP_BASIC = {
|
||||
|
||||
def unet_to_diffusers(unet_config):
|
||||
num_res_blocks = unet_config["num_res_blocks"]
|
||||
attention_resolutions = unet_config["attention_resolutions"]
|
||||
channel_mult = unet_config["channel_mult"]
|
||||
transformer_depth = unet_config["transformer_depth"]
|
||||
transformer_depth = unet_config["transformer_depth"][:]
|
||||
transformer_depth_output = unet_config["transformer_depth_output"][:]
|
||||
num_blocks = len(channel_mult)
|
||||
if isinstance(num_res_blocks, int):
|
||||
num_res_blocks = [num_res_blocks] * num_blocks
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = [transformer_depth] * num_blocks
|
||||
|
||||
transformers_per_layer = []
|
||||
res = 1
|
||||
for i in range(num_blocks):
|
||||
transformers = 0
|
||||
if res in attention_resolutions:
|
||||
transformers = transformer_depth[i]
|
||||
transformers_per_layer.append(transformers)
|
||||
res *= 2
|
||||
|
||||
transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])
|
||||
transformers_mid = unet_config.get("transformer_depth_middle", None)
|
||||
|
||||
diffusers_unet_map = {}
|
||||
for x in range(num_blocks):
|
||||
@ -196,10 +183,11 @@ def unet_to_diffusers(unet_config):
|
||||
for i in range(num_res_blocks[x]):
|
||||
for b in UNET_MAP_RESNET:
|
||||
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
|
||||
if transformers_per_layer[x] > 0:
|
||||
num_transformers = transformer_depth.pop(0)
|
||||
if num_transformers > 0:
|
||||
for b in UNET_MAP_ATTENTIONS:
|
||||
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
|
||||
for t in range(transformers_per_layer[x]):
|
||||
for t in range(num_transformers):
|
||||
for b in TRANSFORMER_BLOCKS:
|
||||
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
||||
n += 1
|
||||
@ -218,7 +206,6 @@ def unet_to_diffusers(unet_config):
|
||||
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
|
||||
|
||||
num_res_blocks = list(reversed(num_res_blocks))
|
||||
transformers_per_layer = list(reversed(transformers_per_layer))
|
||||
for x in range(num_blocks):
|
||||
n = (num_res_blocks[x] + 1) * x
|
||||
l = num_res_blocks[x] + 1
|
||||
@ -227,11 +214,12 @@ def unet_to_diffusers(unet_config):
|
||||
for b in UNET_MAP_RESNET:
|
||||
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
|
||||
c += 1
|
||||
if transformers_per_layer[x] > 0:
|
||||
num_transformers = transformer_depth_output.pop()
|
||||
if num_transformers > 0:
|
||||
c += 1
|
||||
for b in UNET_MAP_ATTENTIONS:
|
||||
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
|
||||
for t in range(transformers_per_layer[x]):
|
||||
for t in range(num_transformers):
|
||||
for b in TRANSFORMER_BLOCKS:
|
||||
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
||||
if i == l - 1:
|
||||
@ -270,9 +258,17 @@ def set_attr(obj, attr, value):
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1])
|
||||
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
||||
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
|
||||
del prev
|
||||
|
||||
def copy_to_param(obj, attr, value):
|
||||
# inplace update tensor instead of replacing it
|
||||
attrs = attr.split(".")
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1])
|
||||
prev.data.copy_(value)
|
||||
|
||||
def get_attr(obj, attr):
|
||||
attrs = attr.split(".")
|
||||
for name in attrs:
|
||||
@ -311,23 +307,25 @@ def bislerp(samples, width, height):
|
||||
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
|
||||
return res
|
||||
|
||||
def generate_bilinear_data(length_old, length_new):
|
||||
coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
|
||||
def generate_bilinear_data(length_old, length_new, device):
|
||||
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
|
||||
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
|
||||
ratios = coords_1 - coords_1.floor()
|
||||
coords_1 = coords_1.to(torch.int64)
|
||||
|
||||
coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
|
||||
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
|
||||
coords_2[:,:,:,-1] -= 1
|
||||
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
||||
coords_2 = coords_2.to(torch.int64)
|
||||
return ratios, coords_1, coords_2
|
||||
|
||||
|
||||
orig_dtype = samples.dtype
|
||||
samples = samples.float()
|
||||
n,c,h,w = samples.shape
|
||||
h_new, w_new = (height, width)
|
||||
|
||||
#linear w
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
|
||||
coords_1 = coords_1.expand((n, c, h, -1))
|
||||
coords_2 = coords_2.expand((n, c, h, -1))
|
||||
ratios = ratios.expand((n, 1, h, -1))
|
||||
@ -340,7 +338,7 @@ def bislerp(samples, width, height):
|
||||
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
|
||||
|
||||
#linear h
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
|
||||
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
||||
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
||||
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
|
||||
@ -351,7 +349,7 @@ def bislerp(samples, width, height):
|
||||
|
||||
result = slerp(pass_1, pass_2, ratios)
|
||||
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
||||
return result
|
||||
return result.to(orig_dtype)
|
||||
|
||||
def lanczos(samples, width, height):
|
||||
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||
|
||||
@ -16,7 +16,7 @@ class BasicScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -36,7 +36,7 @@ class KarrasScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -54,7 +54,7 @@ class ExponentialScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -73,7 +73,7 @@ class PolyexponentialScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -81,6 +81,25 @@ class PolyexponentialScheduler:
|
||||
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
||||
return (sigmas, )
|
||||
|
||||
class SDTurboScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, model, steps):
|
||||
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
|
||||
sigmas = model.model.model_sampling.sigma(timesteps)
|
||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||
return (sigmas, )
|
||||
|
||||
class VPScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -92,7 +111,7 @@ class VPScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -109,7 +128,7 @@ class SplitSigmas:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS","SIGMAS")
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/sigmas"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -118,6 +137,24 @@ class SplitSigmas:
|
||||
sigmas2 = sigmas[step:]
|
||||
return (sigmas1, sigmas2)
|
||||
|
||||
class FlipSigmas:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"sigmas": ("SIGMAS", ),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/sigmas"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, sigmas):
|
||||
sigmas = sigmas.flip(0)
|
||||
if sigmas[0] == 0:
|
||||
sigmas[0] = 0.0001
|
||||
return (sigmas,)
|
||||
|
||||
class KSamplerSelect:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -126,12 +163,12 @@ class KSamplerSelect:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, sampler_name):
|
||||
sampler = comfy.samplers.sampler_class(sampler_name)()
|
||||
sampler = comfy.samplers.sampler_object(sampler_name)
|
||||
return (sampler, )
|
||||
|
||||
class SamplerDPMPP_2M_SDE:
|
||||
@ -145,7 +182,7 @@ class SamplerDPMPP_2M_SDE:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
@ -154,7 +191,7 @@ class SamplerDPMPP_2M_SDE:
|
||||
sampler_name = "dpmpp_2m_sde"
|
||||
else:
|
||||
sampler_name = "dpmpp_2m_sde_gpu"
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})()
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
|
||||
return (sampler, )
|
||||
|
||||
|
||||
@ -169,7 +206,7 @@ class SamplerDPMPP_SDE:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
@ -178,7 +215,7 @@ class SamplerDPMPP_SDE:
|
||||
sampler_name = "dpmpp_sde"
|
||||
else:
|
||||
sampler_name = "dpmpp_sde_gpu"
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})()
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
|
||||
return (sampler, )
|
||||
|
||||
class SamplerCustom:
|
||||
@ -188,7 +225,7 @@ class SamplerCustom:
|
||||
{"model": ("MODEL",),
|
||||
"add_noise": ("BOOLEAN", {"default": True}),
|
||||
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"sampler": ("SAMPLER", ),
|
||||
@ -234,13 +271,15 @@ class SamplerCustom:
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SamplerCustom": SamplerCustom,
|
||||
"BasicScheduler": BasicScheduler,
|
||||
"KarrasScheduler": KarrasScheduler,
|
||||
"ExponentialScheduler": ExponentialScheduler,
|
||||
"PolyexponentialScheduler": PolyexponentialScheduler,
|
||||
"VPScheduler": VPScheduler,
|
||||
"SDTurboScheduler": SDTurboScheduler,
|
||||
"KSamplerSelect": KSamplerSelect,
|
||||
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
|
||||
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
|
||||
"BasicScheduler": BasicScheduler,
|
||||
"SplitSigmas": SplitSigmas,
|
||||
"FlipSigmas": FlipSigmas,
|
||||
}
|
||||
|
||||
175
comfy_extras/nodes_images.py
Normal file
175
comfy_extras/nodes_images.py
Normal file
@ -0,0 +1,175 @@
|
||||
import nodes
|
||||
import folder_paths
|
||||
from comfy.cli_args import args
|
||||
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
|
||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
class ImageCrop:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "crop"
|
||||
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def crop(self, image, width, height, x, y):
|
||||
x = min(x, image.shape[2] - 1)
|
||||
y = min(y, image.shape[1] - 1)
|
||||
to_x = width + x
|
||||
to_y = height + y
|
||||
img = image[:,y:to_y, x:to_x, :]
|
||||
return (img,)
|
||||
|
||||
class RepeatImageBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "repeat"
|
||||
|
||||
CATEGORY = "image/batch"
|
||||
|
||||
def repeat(self, image, amount):
|
||||
s = image.repeat((amount, 1,1,1))
|
||||
return (s,)
|
||||
|
||||
class SaveAnimatedWEBP:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
methods = {"default": 4, "fastest": 0, "slowest": 6}
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
||||
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
|
||||
"lossless": ("BOOLEAN", {"default": True}),
|
||||
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
|
||||
"method": (list(s.methods.keys()),),
|
||||
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_images"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
|
||||
method = self.methods.get(method)
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
pil_images = []
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
pil_images.append(img)
|
||||
|
||||
metadata = pil_images[0].getexif()
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
|
||||
if extra_pnginfo is not None:
|
||||
inital_exif = 0x010f
|
||||
for x in extra_pnginfo:
|
||||
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
|
||||
inital_exif -= 1
|
||||
|
||||
if num_frames == 0:
|
||||
num_frames = len(pil_images)
|
||||
|
||||
c = len(pil_images)
|
||||
for i in range(0, c, num_frames):
|
||||
file = f"{filename}_{counter:05}_.webp"
|
||||
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
|
||||
animated = num_frames != 1
|
||||
return { "ui": { "images": results, "animated": (animated,) } }
|
||||
|
||||
class SaveAnimatedPNG:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
||||
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
|
||||
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_images"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
pil_images = []
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
pil_images.append(img)
|
||||
|
||||
metadata = None
|
||||
if not args.disable_metadata:
|
||||
metadata = PngInfo()
|
||||
if prompt is not None:
|
||||
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
|
||||
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
|
||||
return { "ui": { "images": results, "animated": (True,)} }
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||
}
|
||||
@ -1,4 +1,5 @@
|
||||
import comfy.utils
|
||||
import torch
|
||||
|
||||
def reshape_latent_to(target_shape, latent):
|
||||
if latent.shape[1:] != target_shape[1:]:
|
||||
@ -67,8 +68,43 @@ class LatentMultiply:
|
||||
samples_out["samples"] = s1 * multiplier
|
||||
return (samples_out,)
|
||||
|
||||
class LatentInterpolate:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples1": ("LATENT",),
|
||||
"samples2": ("LATENT",),
|
||||
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples1, samples2, ratio):
|
||||
samples_out = samples1.copy()
|
||||
|
||||
s1 = samples1["samples"]
|
||||
s2 = samples2["samples"]
|
||||
|
||||
s2 = reshape_latent_to(s1.shape, s2)
|
||||
|
||||
m1 = torch.linalg.vector_norm(s1, dim=(1))
|
||||
m2 = torch.linalg.vector_norm(s2, dim=(1))
|
||||
|
||||
s1 = torch.nan_to_num(s1 / m1)
|
||||
s2 = torch.nan_to_num(s2 / m2)
|
||||
|
||||
t = (s1 * ratio + s2 * (1.0 - ratio))
|
||||
mt = torch.linalg.vector_norm(t, dim=(1))
|
||||
st = torch.nan_to_num(t / mt)
|
||||
|
||||
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
||||
return (samples_out,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentAdd": LatentAdd,
|
||||
"LatentSubtract": LatentSubtract,
|
||||
"LatentMultiply": LatentMultiply,
|
||||
"LatentInterpolate": LatentInterpolate,
|
||||
}
|
||||
|
||||
205
comfy_extras/nodes_model_advanced.py
Normal file
205
comfy_extras/nodes_model_advanced.py
Normal file
@ -0,0 +1,205 @@
|
||||
import folder_paths
|
||||
import comfy.sd
|
||||
import comfy.model_sampling
|
||||
import torch
|
||||
|
||||
class LCM(comfy.model_sampling.EPS):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
x0 = model_input - model_output * sigma
|
||||
|
||||
sigma_data = 0.5
|
||||
scaled_timestep = timestep * 10.0 #timestep_scaling
|
||||
|
||||
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
|
||||
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
|
||||
|
||||
return c_out * x0 + c_skip * model_input
|
||||
|
||||
class ModelSamplingDiscreteDistilled(torch.nn.Module):
|
||||
original_timesteps = 50
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sigma_data = 1.0
|
||||
timesteps = 1000
|
||||
beta_start = 0.00085
|
||||
beta_end = 0.012
|
||||
|
||||
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
|
||||
self.skip_steps = timesteps // self.original_timesteps
|
||||
|
||||
|
||||
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
|
||||
for x in range(self.original_timesteps):
|
||||
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
|
||||
|
||||
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def set_sigmas(self, sigmas):
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
|
||||
|
||||
def sigma(self, timestep):
|
||||
t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
|
||||
low_idx = t.floor().long()
|
||||
high_idx = t.ceil().long()
|
||||
w = t.frac()
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp().to(timestep.device)
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
return 999999999.9
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
percent = 1.0 - percent
|
||||
return self.sigma(torch.tensor(percent * 999.0)).item()
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||
|
||||
class ModelSamplingDiscrete:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"sampling": (["eps", "v_prediction", "lcm"],),
|
||||
"zsnr": ("BOOLEAN", {"default": False}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, sampling, zsnr):
|
||||
m = model.clone()
|
||||
|
||||
sampling_base = comfy.model_sampling.ModelSamplingDiscrete
|
||||
if sampling == "eps":
|
||||
sampling_type = comfy.model_sampling.EPS
|
||||
elif sampling == "v_prediction":
|
||||
sampling_type = comfy.model_sampling.V_PREDICTION
|
||||
elif sampling == "lcm":
|
||||
sampling_type = LCM
|
||||
sampling_base = ModelSamplingDiscreteDistilled
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced()
|
||||
if zsnr:
|
||||
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
|
||||
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
class ModelSamplingContinuousEDM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"sampling": (["v_prediction", "eps"],),
|
||||
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
||||
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, sampling, sigma_max, sigma_min):
|
||||
m = model.clone()
|
||||
|
||||
if sampling == "eps":
|
||||
sampling_type = comfy.model_sampling.EPS
|
||||
elif sampling == "v_prediction":
|
||||
sampling_type = comfy.model_sampling.V_PREDICTION
|
||||
|
||||
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced()
|
||||
model_sampling.set_sigma_range(sigma_min, sigma_max)
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
class RescaleCFG:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, multiplier):
|
||||
def rescale_cfg(args):
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
sigma = args["sigma"]
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
|
||||
x_orig = args["input"]
|
||||
|
||||
#rescale cfg has to be done on v-pred model output
|
||||
x = x_orig / (sigma * sigma + 1.0)
|
||||
cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
|
||||
uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
|
||||
|
||||
#rescalecfg
|
||||
x_cfg = uncond + cond_scale * (cond - uncond)
|
||||
ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True)
|
||||
ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True)
|
||||
|
||||
x_rescaled = x_cfg * (ro_pos / ro_cfg)
|
||||
x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg
|
||||
|
||||
return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_cfg_function(rescale_cfg)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelSamplingDiscrete": ModelSamplingDiscrete,
|
||||
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
|
||||
"RescaleCFG": RescaleCFG,
|
||||
}
|
||||
53
comfy_extras/nodes_model_downscale.py
Normal file
53
comfy_extras/nodes_model_downscale.py
Normal file
@ -0,0 +1,53 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
class PatchModelAddDownscale:
|
||||
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
|
||||
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"downscale_after_skip": ("BOOLEAN", {"default": True}),
|
||||
"downscale_method": (s.upscale_methods,),
|
||||
"upscale_method": (s.upscale_methods,),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
||||
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
def input_block_patch(h, transformer_options):
|
||||
if transformer_options["block"][1] == block_number:
|
||||
sigma = transformer_options["sigmas"][0].item()
|
||||
if sigma <= sigma_start and sigma >= sigma_end:
|
||||
h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
|
||||
return h
|
||||
|
||||
def output_block_patch(h, hsp, transformer_options):
|
||||
if h.shape[2] != hsp.shape[2]:
|
||||
h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
|
||||
return h, hsp
|
||||
|
||||
m = model.clone()
|
||||
if downscale_after_skip:
|
||||
m.set_model_input_block_patch_after_skip(input_block_patch)
|
||||
else:
|
||||
m.set_model_input_block_patch(input_block_patch)
|
||||
m.set_model_output_block_patch(output_block_patch)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PatchModelAddDownscale": PatchModelAddDownscale,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# Sampling
|
||||
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
|
||||
}
|
||||
@ -23,7 +23,7 @@ class Blend:
|
||||
"max": 1.0,
|
||||
"step": 0.01
|
||||
}),
|
||||
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
|
||||
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
|
||||
},
|
||||
}
|
||||
|
||||
@ -54,6 +54,8 @@ class Blend:
|
||||
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
|
||||
elif mode == "soft_light":
|
||||
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
|
||||
elif mode == "difference":
|
||||
return img1 - img2
|
||||
else:
|
||||
raise ValueError(f"Unsupported blend mode: {mode}")
|
||||
|
||||
@ -126,7 +128,7 @@ class Quantize:
|
||||
"max": 256,
|
||||
"step": 1
|
||||
}),
|
||||
"dither": (["none", "floyd-steinberg"],),
|
||||
"dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
|
||||
},
|
||||
}
|
||||
|
||||
@ -135,19 +137,47 @@ class Quantize:
|
||||
|
||||
CATEGORY = "image/postprocessing"
|
||||
|
||||
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
|
||||
def bayer(im, pal_im, order):
|
||||
def normalized_bayer_matrix(n):
|
||||
if n == 0:
|
||||
return np.zeros((1,1), "float32")
|
||||
else:
|
||||
q = 4 ** n
|
||||
m = q * normalized_bayer_matrix(n - 1)
|
||||
return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q
|
||||
|
||||
num_colors = len(pal_im.getpalette()) // 3
|
||||
spread = 2 * 256 / num_colors
|
||||
bayer_n = int(math.log2(order))
|
||||
bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)
|
||||
|
||||
result = torch.from_numpy(np.array(im).astype(np.float32))
|
||||
tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
|
||||
th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
|
||||
tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
|
||||
result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255)
|
||||
result = result.to(dtype=torch.uint8)
|
||||
|
||||
im = Image.fromarray(result.cpu().numpy())
|
||||
im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
|
||||
return im
|
||||
|
||||
def quantize(self, image: torch.Tensor, colors: int, dither: str):
|
||||
batch_size, height, width, _ = image.shape
|
||||
result = torch.zeros_like(image)
|
||||
|
||||
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
|
||||
|
||||
for b in range(batch_size):
|
||||
tensor_image = image[b]
|
||||
img = (tensor_image * 255).to(torch.uint8).numpy()
|
||||
pil_image = Image.fromarray(img, mode='RGB')
|
||||
im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')
|
||||
|
||||
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
|
||||
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
|
||||
pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
|
||||
|
||||
if dither == "none":
|
||||
quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
|
||||
elif dither == "floyd-steinberg":
|
||||
quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
|
||||
elif dither.startswith("bayer"):
|
||||
order = int(dither.split('-')[-1])
|
||||
quantized_image = Quantize.bayer(im, pal_im, order)
|
||||
|
||||
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
|
||||
result[b] = quantized_array
|
||||
|
||||
@ -4,7 +4,7 @@ class LatentRebatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "latents": ("LATENT",),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
INPUT_IS_LIST = True
|
||||
|
||||
89
comfy_extras/nodes_video_model.py
Normal file
89
comfy_extras/nodes_video_model.py
Normal file
@ -0,0 +1,89 @@
|
||||
import nodes
|
||||
import torch
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
import folder_paths
|
||||
|
||||
|
||||
class ImageOnlyCheckpointLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
CATEGORY = "loaders/video_models"
|
||||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return (out[0], out[3], out[2])
|
||||
|
||||
|
||||
class SVD_img2vid_Conditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_vision": ("CLIP_VISION",),
|
||||
"init_image": ("IMAGE",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
|
||||
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
|
||||
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
|
||||
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
|
||||
output = clip_vision.encode_image(init_image)
|
||||
pooled = output.image_embeds.unsqueeze(0)
|
||||
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
|
||||
encode_pixels = pixels[:,:,:,:3]
|
||||
if augmentation_level > 0:
|
||||
encode_pixels += torch.randn_like(pixels) * augmentation_level
|
||||
t = vae.encode(encode_pixels)
|
||||
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
|
||||
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
|
||||
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
|
||||
return (positive, negative, {"samples":latent})
|
||||
|
||||
class VideoLinearCFGGuidance:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "sampling/video_models"
|
||||
|
||||
def patch(self, model, min_cfg):
|
||||
def linear_cfg(args):
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
|
||||
scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1))
|
||||
return uncond + scale * (cond - uncond)
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_cfg_function(linear_cfg)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
|
||||
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
|
||||
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
|
||||
}
|
||||
17
execution.py
17
execution.py
@ -683,6 +683,7 @@ def validate_prompt(prompt):
|
||||
|
||||
return (True, None, list(good_outputs), node_errors)
|
||||
|
||||
MAXIMUM_HISTORY_SIZE = 10000
|
||||
|
||||
class PromptQueue:
|
||||
def __init__(self, server):
|
||||
@ -715,6 +716,8 @@ class PromptQueue:
|
||||
def task_done(self, item_id, outputs):
|
||||
with self.mutex:
|
||||
prompt = self.currently_running.pop(item_id)
|
||||
if len(self.history) > MAXIMUM_HISTORY_SIZE:
|
||||
self.history.pop(next(iter(self.history)))
|
||||
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
|
||||
for o in outputs:
|
||||
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
||||
@ -749,10 +752,20 @@ class PromptQueue:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_history(self, prompt_id=None):
|
||||
def get_history(self, prompt_id=None, max_items=None, offset=-1):
|
||||
with self.mutex:
|
||||
if prompt_id is None:
|
||||
return copy.deepcopy(self.history)
|
||||
out = {}
|
||||
i = 0
|
||||
if offset < 0 and max_items is not None:
|
||||
offset = len(self.history) - max_items
|
||||
for k in self.history:
|
||||
if i >= offset:
|
||||
out[k] = self.history[k]
|
||||
if max_items is not None and len(out) >= max_items:
|
||||
break
|
||||
i += 1
|
||||
return out
|
||||
elif prompt_id in self.history:
|
||||
return {prompt_id: copy.deepcopy(self.history[prompt_id])}
|
||||
else:
|
||||
|
||||
@ -38,7 +38,10 @@ input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "inp
|
||||
filename_list_cache = {}
|
||||
|
||||
if not os.path.exists(input_directory):
|
||||
os.makedirs(input_directory)
|
||||
try:
|
||||
os.makedirs(input_directory)
|
||||
except:
|
||||
print("Failed to create input directory")
|
||||
|
||||
def set_output_directory(output_dir):
|
||||
global output_directory
|
||||
@ -227,8 +230,12 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
|
||||
full_output_folder = os.path.join(output_dir, subfolder)
|
||||
|
||||
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
|
||||
print("Saving image outside the output folder is not allowed.")
|
||||
return {}
|
||||
err = "**** ERROR: Saving image outside the output folder is not allowed." + \
|
||||
"\n full_output_folder: " + os.path.abspath(full_output_folder) + \
|
||||
"\n output_dir: " + output_dir + \
|
||||
"\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
|
||||
print(err)
|
||||
raise Exception(err)
|
||||
|
||||
try:
|
||||
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
||||
|
||||
@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
|
||||
self.taesd = taesd
|
||||
|
||||
def decode_latent_to_preview(self, x0):
|
||||
x_sample = self.taesd.decoder(x0)[0].detach()
|
||||
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
|
||||
x_sample = x_sample.sub(0.5).mul(2)
|
||||
|
||||
x_sample = self.taesd.decode(x0[:1])[0].detach()
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
|
||||
12
main.py
12
main.py
@ -88,6 +88,7 @@ def cuda_malloc_warning():
|
||||
|
||||
def prompt_worker(q, server):
|
||||
e = execution.PromptExecutor(server)
|
||||
last_gc_collect = 0
|
||||
while True:
|
||||
item, item_id = q.get()
|
||||
execution_start_time = time.perf_counter()
|
||||
@ -97,9 +98,14 @@ def prompt_worker(q, server):
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
|
||||
|
||||
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
|
||||
gc.collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
print("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||
if (current_time - last_gc_collect) > 10.0:
|
||||
gc.collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
last_gc_collect = current_time
|
||||
print("gc collect")
|
||||
|
||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||
|
||||
91
nodes.py
91
nodes.py
@ -248,8 +248,8 @@ class ConditioningSetTimestepRange:
|
||||
c = []
|
||||
for t in conditioning:
|
||||
d = t[1].copy()
|
||||
d['start_percent'] = 1.0 - start
|
||||
d['end_percent'] = 1.0 - end
|
||||
d['start_percent'] = start
|
||||
d['end_percent'] = end
|
||||
n = [t[0], d]
|
||||
c.append(n)
|
||||
return (c, )
|
||||
@ -572,10 +572,69 @@ class LoraLoader:
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
|
||||
return (model_lora, clip_lora)
|
||||
|
||||
class VAELoader:
|
||||
class LoraLoaderModelOnly(LoraLoader):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_lora_model_only"
|
||||
|
||||
def load_lora_model_only(self, model, lora_name, strength_model):
|
||||
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
|
||||
|
||||
class VAELoader:
|
||||
@staticmethod
|
||||
def vae_list():
|
||||
vaes = folder_paths.get_filename_list("vae")
|
||||
approx_vaes = folder_paths.get_filename_list("vae_approx")
|
||||
sdxl_taesd_enc = False
|
||||
sdxl_taesd_dec = False
|
||||
sd1_taesd_enc = False
|
||||
sd1_taesd_dec = False
|
||||
|
||||
for v in approx_vaes:
|
||||
if v.startswith("taesd_decoder."):
|
||||
sd1_taesd_dec = True
|
||||
elif v.startswith("taesd_encoder."):
|
||||
sd1_taesd_enc = True
|
||||
elif v.startswith("taesdxl_decoder."):
|
||||
sdxl_taesd_dec = True
|
||||
elif v.startswith("taesdxl_encoder."):
|
||||
sdxl_taesd_enc = True
|
||||
if sd1_taesd_dec and sd1_taesd_enc:
|
||||
vaes.append("taesd")
|
||||
if sdxl_taesd_dec and sdxl_taesd_enc:
|
||||
vaes.append("taesdxl")
|
||||
return vaes
|
||||
|
||||
@staticmethod
|
||||
def load_taesd(name):
|
||||
sd = {}
|
||||
approx_vaes = folder_paths.get_filename_list("vae_approx")
|
||||
|
||||
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
|
||||
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
|
||||
|
||||
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
|
||||
for k in enc:
|
||||
sd["taesd_encoder.{}".format(k)] = enc[k]
|
||||
|
||||
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
|
||||
for k in dec:
|
||||
sd["taesd_decoder.{}".format(k)] = dec[k]
|
||||
|
||||
if name == "taesd":
|
||||
sd["vae_scale"] = torch.tensor(0.18215)
|
||||
elif name == "taesdxl":
|
||||
sd["vae_scale"] = torch.tensor(0.13025)
|
||||
return sd
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "vae_name": (s.vae_list(), )}}
|
||||
RETURN_TYPES = ("VAE",)
|
||||
FUNCTION = "load_vae"
|
||||
|
||||
@ -583,8 +642,11 @@ class VAELoader:
|
||||
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name):
|
||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
if vae_name in ["taesd", "taesdxl"]:
|
||||
sd = self.load_taesd(vae_name)
|
||||
else:
|
||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
vae = comfy.sd.VAE(sd=sd)
|
||||
return (vae,)
|
||||
|
||||
@ -685,7 +747,7 @@ class ControlNetApplyAdvanced:
|
||||
if prev_cnet in cnets:
|
||||
c_net = cnets[prev_cnet]
|
||||
else:
|
||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
|
||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
|
||||
c_net.set_previous_controlnet(prev_cnet)
|
||||
cnets[prev_cnet] = c_net
|
||||
|
||||
@ -1218,7 +1280,7 @@ class KSampler:
|
||||
{"model": ("MODEL",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
|
||||
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
|
||||
"positive": ("CONDITIONING", ),
|
||||
@ -1244,7 +1306,7 @@ class KSamplerAdvanced:
|
||||
"add_noise": (["enable", "disable"], ),
|
||||
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
|
||||
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
|
||||
"positive": ("CONDITIONING", ),
|
||||
@ -1275,6 +1337,7 @@ class SaveImage:
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
self.compress_level = 4
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1308,7 +1371,7 @@ class SaveImage:
|
||||
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
|
||||
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
|
||||
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
@ -1323,6 +1386,7 @@ class PreviewImage(SaveImage):
|
||||
self.output_dir = folder_paths.get_temp_directory()
|
||||
self.type = "temp"
|
||||
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
|
||||
self.compress_level = 1
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1654,6 +1718,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
"ConditioningZeroOut": ConditioningZeroOut,
|
||||
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
|
||||
"LoraLoaderModelOnly": LoraLoaderModelOnly,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -1759,7 +1824,7 @@ def load_custom_nodes():
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
node_import_times = []
|
||||
for custom_node_path in node_paths:
|
||||
possible_modules = os.listdir(custom_node_path)
|
||||
possible_modules = os.listdir(os.path.realpath(custom_node_path))
|
||||
if "__pycache__" in possible_modules:
|
||||
possible_modules.remove("__pycache__")
|
||||
|
||||
@ -1799,6 +1864,10 @@ def init_custom_nodes():
|
||||
"nodes_custom_sampler.py",
|
||||
"nodes_subflow.py",
|
||||
"nodes_hypertile.py",
|
||||
"nodes_model_advanced.py",
|
||||
"nodes_model_downscale.py",
|
||||
"nodes_images.py",
|
||||
"nodes_video_model.py",
|
||||
]
|
||||
|
||||
for node_file in extras_files:
|
||||
|
||||
10
server.py
10
server.py
@ -83,7 +83,8 @@ class PromptServer():
|
||||
if args.enable_cors_header:
|
||||
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
||||
|
||||
self.app = web.Application(client_max_size=104857600, middlewares=middlewares)
|
||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||
self.sockets = dict()
|
||||
self.web_root = os.path.join(os.path.dirname(
|
||||
os.path.realpath(__file__)), "web")
|
||||
@ -433,7 +434,10 @@ class PromptServer():
|
||||
|
||||
@routes.get("/history")
|
||||
async def get_history(request):
|
||||
return web.json_response(self.prompt_queue.get_history())
|
||||
max_items = request.rel_url.query.get("max_items", None)
|
||||
if max_items is not None:
|
||||
max_items = int(max_items)
|
||||
return web.json_response(self.prompt_queue.get_history(max_items=max_items))
|
||||
|
||||
@routes.get("/history/{prompt_id}")
|
||||
async def get_history(request):
|
||||
@ -590,7 +594,7 @@ class PromptServer():
|
||||
bytesIO = BytesIO()
|
||||
header = struct.pack(">I", type_num)
|
||||
bytesIO.write(header)
|
||||
image.save(bytesIO, format=image_type, quality=95, compress_level=4)
|
||||
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
|
||||
preview_bytes = bytesIO.getvalue()
|
||||
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||
|
||||
|
||||
@ -14,10 +14,10 @@ const lg = require("../utils/litegraph");
|
||||
* @param { InstanceType<Ez["EzGraph"]> } graph
|
||||
* @param { InstanceType<Ez["EzInput"]> } input
|
||||
* @param { string } widgetType
|
||||
* @param { boolean } hasControlWidget
|
||||
* @param { number } controlWidgetCount
|
||||
* @returns
|
||||
*/
|
||||
async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasControlWidget) {
|
||||
async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) {
|
||||
// Connect to primitive and ensure its still connected after
|
||||
let primitive = ez.PrimitiveNode();
|
||||
primitive.outputs[0].connectTo(input);
|
||||
@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro
|
||||
expect(valueWidget.widget.type).toBe(widgetType);
|
||||
|
||||
// Check if control_after_generate should be added
|
||||
if (hasControlWidget) {
|
||||
if (controlWidgetCount) {
|
||||
const controlWidget = primitive.widgets.control_after_generate;
|
||||
expect(controlWidget.widget.type).toBe("combo");
|
||||
if(widgetType === "combo") {
|
||||
const filterWidget = primitive.widgets.control_filter_list;
|
||||
expect(filterWidget.widget.type).toBe("string");
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we dont have other widgets
|
||||
expect(primitive.node.widgets).toHaveLength(1 + +!!hasControlWidget);
|
||||
expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount);
|
||||
});
|
||||
|
||||
return primitive;
|
||||
@ -55,8 +59,8 @@ describe("widget inputs", () => {
|
||||
});
|
||||
|
||||
[
|
||||
{ name: "int", type: "INT", widget: "number", control: true },
|
||||
{ name: "float", type: "FLOAT", widget: "number", control: true },
|
||||
{ name: "int", type: "INT", widget: "number", control: 1 },
|
||||
{ name: "float", type: "FLOAT", widget: "number", control: 1 },
|
||||
{ name: "text", type: "STRING" },
|
||||
{
|
||||
name: "customtext",
|
||||
@ -64,7 +68,7 @@ describe("widget inputs", () => {
|
||||
opt: { multiline: true },
|
||||
},
|
||||
{ name: "toggle", type: "BOOLEAN" },
|
||||
{ name: "combo", type: ["a", "b", "c"], control: true },
|
||||
{ name: "combo", type: ["a", "b", "c"], control: 2 },
|
||||
].forEach((c) => {
|
||||
test(`widget conversion + primitive works on ${c.name}`, async () => {
|
||||
const { ez, graph } = await start({
|
||||
@ -106,7 +110,7 @@ describe("widget inputs", () => {
|
||||
n.widgets.ckpt_name.convertToInput();
|
||||
expect(n.inputs.length).toEqual(inputCount + 1);
|
||||
|
||||
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", true);
|
||||
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2);
|
||||
|
||||
// Disconnect & reconnect
|
||||
primitive.outputs[0].connections[0].disconnect();
|
||||
@ -226,7 +230,7 @@ describe("widget inputs", () => {
|
||||
// Reload and ensure it still only has 1 converted widget
|
||||
if (!assertNotNullOrUndefined(input)) return;
|
||||
|
||||
await connectPrimitiveAndReload(ez, graph, input, "number", true);
|
||||
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
|
||||
n = graph.find(n);
|
||||
expect(n.widgets).toHaveLength(1);
|
||||
w = n.widgets.example;
|
||||
@ -258,7 +262,7 @@ describe("widget inputs", () => {
|
||||
|
||||
// Reload and ensure it still only has 1 converted widget
|
||||
if (assertNotNullOrUndefined(input)) {
|
||||
await connectPrimitiveAndReload(ez, graph, input, "number", true);
|
||||
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
|
||||
n = graph.find(n);
|
||||
expect(n.widgets).toHaveLength(1);
|
||||
expect(n.widgets.example.isConvertedToInput).toBeTruthy();
|
||||
@ -316,4 +320,76 @@ describe("widget inputs", () => {
|
||||
n1.outputs[0].connectTo(n2.inputs[0]);
|
||||
expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow();
|
||||
});
|
||||
|
||||
test("combo primitive can filter list when control_after_generate called", async () => {
|
||||
const { ez } = await start({
|
||||
mockNodeDefs: {
|
||||
...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }),
|
||||
},
|
||||
});
|
||||
|
||||
const n1 = ez.TestNode1();
|
||||
n1.widgets.example.convertToInput();
|
||||
const p = ez.PrimitiveNode()
|
||||
p.outputs[0].connectTo(n1.inputs[0]);
|
||||
|
||||
const value = p.widgets.value;
|
||||
const control = p.widgets.control_after_generate.widget;
|
||||
const filter = p.widgets.control_filter_list;
|
||||
|
||||
expect(p.widgets.length).toBe(3);
|
||||
control.value = "increment";
|
||||
expect(value.value).toBe("A");
|
||||
|
||||
// Manually trigger after queue when set to increment
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
|
||||
// Filter to items containing D
|
||||
filter.value = "D";
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("D");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("DD");
|
||||
|
||||
// Check decrement
|
||||
value.value = "BBB";
|
||||
control.value = "decrement";
|
||||
filter.value = "B";
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("BB");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
|
||||
// Check regex works
|
||||
value.value = "BBB";
|
||||
filter.value = "/[AB]|^C$/";
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("AAA");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("BB");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("AA");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("C");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("A");
|
||||
|
||||
// Check random
|
||||
control.value = "randomize";
|
||||
filter.value = "/D/";
|
||||
for(let i = 0; i < 100; i++) {
|
||||
control["afterQueued"]();
|
||||
expect(value.value === "D" || value.value === "DD").toBeTruthy();
|
||||
}
|
||||
|
||||
// Ensure it doesnt apply when fixed
|
||||
control.value = "fixed";
|
||||
value.value = "B";
|
||||
filter.value = "C";
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
});
|
||||
});
|
||||
|
||||
@ -174,6 +174,213 @@ const colorPalettes = {
|
||||
"tr-odd-bg-color": "#073642",
|
||||
}
|
||||
},
|
||||
},
|
||||
"arc": {
|
||||
"id": "arc",
|
||||
"name": "Arc",
|
||||
"colors": {
|
||||
"node_slot": {
|
||||
"BOOLEAN": "",
|
||||
"CLIP": "#eacb8b",
|
||||
"CLIP_VISION": "#A8DADC",
|
||||
"CLIP_VISION_OUTPUT": "#ad7452",
|
||||
"CONDITIONING": "#cf876f",
|
||||
"CONTROL_NET": "#00d78d",
|
||||
"CONTROL_NET_WEIGHTS": "",
|
||||
"FLOAT": "",
|
||||
"GLIGEN": "",
|
||||
"IMAGE": "#80a1c0",
|
||||
"IMAGEUPLOAD": "",
|
||||
"INT": "",
|
||||
"LATENT": "#b38ead",
|
||||
"LATENT_KEYFRAME": "",
|
||||
"MASK": "#a3bd8d",
|
||||
"MODEL": "#8978a7",
|
||||
"SAMPLER": "",
|
||||
"SIGMAS": "",
|
||||
"STRING": "",
|
||||
"STYLE_MODEL": "#C2FFAE",
|
||||
"T2I_ADAPTER_WEIGHTS": "",
|
||||
"TAESD": "#DCC274",
|
||||
"TIMESTEP_KEYFRAME": "",
|
||||
"UPSCALE_MODEL": "",
|
||||
"VAE": "#be616b"
|
||||
},
|
||||
"litegraph_base": {
|
||||
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAYAAABw4pVUAAAACXBIWXMAAAsTAAALEwEAmpwYAAABcklEQVR4nO3YMUoDARgF4RfxBqZI6/0vZqFn0MYtrLIQMFN8U6V4LAtD+Jm9XG/v30OGl2e/AP7yevz4+vx45nvgF/+QGITEICQGITEIiUFIjNNC3q43u3/YnRJyPOzeQ+0e220nhRzReC8e7R7bbdvl+Jal1Bs46jEIiUFIDEJiEBKDkBhKPbZT6qHdptRTu02p53DUYxASg5AYhMQgJAYhMZR6bKfUQ7tNqad2m1LP4ajHICQGITEIiUFIDEJiKPXYTqmHdptST+02pZ7DUY9BSAxCYhASg5AYhMRQ6rGdUg/tNqWe2m1KPYejHoOQGITEICQGITEIiaHUYzulHtptSj2125R6Dkc9BiExCIlBSAxCYhASQ6nHdko9tNuUemq3KfUcjnoMQmIQEoOQGITEICSGUo/tlHpotyn11G5T6jkc9RiExCAkBiExCIlBSAylHtsp9dBuU+qp3abUczjqMQiJQUgMQmIQEoOQGITE+AHFISNQrFTGuwAAAABJRU5ErkJggg==",
|
||||
"CLEAR_BACKGROUND_COLOR": "#2b2f38",
|
||||
"NODE_TITLE_COLOR": "#b2b7bd",
|
||||
"NODE_SELECTED_TITLE_COLOR": "#FFF",
|
||||
"NODE_TEXT_SIZE": 14,
|
||||
"NODE_TEXT_COLOR": "#AAA",
|
||||
"NODE_SUBTEXT_SIZE": 12,
|
||||
"NODE_DEFAULT_COLOR": "#2b2f38",
|
||||
"NODE_DEFAULT_BGCOLOR": "#242730",
|
||||
"NODE_DEFAULT_BOXCOLOR": "#6e7581",
|
||||
"NODE_DEFAULT_SHAPE": "box",
|
||||
"NODE_BOX_OUTLINE_COLOR": "#FFF",
|
||||
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
|
||||
"DEFAULT_GROUP_FONT": 22,
|
||||
"WIDGET_BGCOLOR": "#2b2f38",
|
||||
"WIDGET_OUTLINE_COLOR": "#6e7581",
|
||||
"WIDGET_TEXT_COLOR": "#DDD",
|
||||
"WIDGET_SECONDARY_TEXT_COLOR": "#b2b7bd",
|
||||
"LINK_COLOR": "#9A9",
|
||||
"EVENT_LINK_COLOR": "#A86",
|
||||
"CONNECTING_LINK_COLOR": "#AFA"
|
||||
},
|
||||
"comfy_base": {
|
||||
"fg-color": "#fff",
|
||||
"bg-color": "#2b2f38",
|
||||
"comfy-menu-bg": "#242730",
|
||||
"comfy-input-bg": "#2b2f38",
|
||||
"input-text": "#ddd",
|
||||
"descrip-text": "#b2b7bd",
|
||||
"drag-text": "#ccc",
|
||||
"error-text": "#ff4444",
|
||||
"border-color": "#6e7581",
|
||||
"tr-even-bg-color": "#2b2f38",
|
||||
"tr-odd-bg-color": "#242730"
|
||||
}
|
||||
},
|
||||
},
|
||||
"nord": {
|
||||
"id": "nord",
|
||||
"name": "Nord",
|
||||
"colors": {
|
||||
"node_slot": {
|
||||
"BOOLEAN": "",
|
||||
"CLIP": "#eacb8b",
|
||||
"CLIP_VISION": "#A8DADC",
|
||||
"CLIP_VISION_OUTPUT": "#ad7452",
|
||||
"CONDITIONING": "#cf876f",
|
||||
"CONTROL_NET": "#00d78d",
|
||||
"CONTROL_NET_WEIGHTS": "",
|
||||
"FLOAT": "",
|
||||
"GLIGEN": "",
|
||||
"IMAGE": "#80a1c0",
|
||||
"IMAGEUPLOAD": "",
|
||||
"INT": "",
|
||||
"LATENT": "#b38ead",
|
||||
"LATENT_KEYFRAME": "",
|
||||
"MASK": "#a3bd8d",
|
||||
"MODEL": "#8978a7",
|
||||
"SAMPLER": "",
|
||||
"SIGMAS": "",
|
||||
"STRING": "",
|
||||
"STYLE_MODEL": "#C2FFAE",
|
||||
"T2I_ADAPTER_WEIGHTS": "",
|
||||
"TAESD": "#DCC274",
|
||||
"TIMESTEP_KEYFRAME": "",
|
||||
"UPSCALE_MODEL": "",
|
||||
"VAE": "#be616b"
|
||||
},
|
||||
"litegraph_base": {
|
||||
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAACXBIWXMAAAsTAAALEwEAmpwYAAAFu2lUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4gPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iQWRvYmUgWE1QIENvcmUgOS4xLWMwMDEgNzkuMTQ2Mjg5OSwgMjAyMy8wNi8yNS0yMDowMTo1NSAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0RXZ0PSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VFdmVudCMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiB4bXA6Q3JlYXRlRGF0ZT0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgeG1wOk1vZGlmeURhdGU9IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIHhtcDpNZXRhZGF0YURhdGU9IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIGRjOmZvcm1hdD0iaW1hZ2UvcG5nIiBwaG90b3Nob3A6Q29sb3JNb2RlPSIzIiB4bXBNTTpJbnN0YW5jZUlEPSJ4bXAuaWlkOjUwNDFhMmZjLTEzNzQtMTk0ZC1hZWY4LTYxMzM1MTVmNjUwMCIgeG1wTU06RG9jdW1lbnRJRD0ieG1wLmRpZDoyMzFiMTBiMC1iNGZiLTAyNGUtYjEyZS0zMDUzMDNjZDA3YzgiIHhtcE1NOk9yaWdpbmFsRG9jdW1lbnRJRD0ieG1wLmRpZDoyMzFiMTBiMC1iNGZiLTAyNGUtYjEyZS0zMDUzMDNjZDA3YzgiPiA8eG1wTU06SGlzdG9yeT4gPHJkZjpTZXE+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJjcmVhdGVkIiBzdEV2dDppbnN0YW5jZUlEPSJ4bXAuaWlkOjIzMWIxMGIwLWI0ZmItMDI0ZS1iMTJlLTMwNTMwM2NkMDdjOCIgc3RFdnQ6d2hlbj0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgc3RFdnQ6c29mdHdhcmVBZ2VudD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIi8+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJzYXZlZCIgc3RFdnQ6aW5zdGFuY2VJRD0ieG1wLmlpZDo1MDQxYTJmYy0xMzc0LTE5NGQtYWVmOC02MTMzNTE1ZjY1MDAiIHN0RXZ0OndoZW49IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIHN0RXZ0OnNvZnR3YXJlQWdlbnQ9IkFkb2JlIFBob3Rvc2hvcCAyNS4xIChXaW5kb3dzKSIgc3RFdnQ6Y2hhbmdlZD0iLyIvPiA8L3JkZjpTZXE+IDwveG1wTU06SGlzdG9yeT4gPC9yZGY6RGVzY3JpcHRpb24+IDwvcmRmOlJERj4gPC94OnhtcG1ldGE+IDw/eHBhY2tldCBlbmQ9InIiPz73jWg/AAAAyUlEQVR42u3WKwoAIBRFQRdiMb1idv9Lsxn9gEFw4Dbb8JCTojbbXEJwjJVL2HKwYMGCBQuWLbDmjr+9zrBGjHl1WVcvy2DBggULFizTWQpewSt4HzwsgwULFiwFr7MUvMtS8D54WLBgGSxYCl7BK3iXZbBgwYIFC5bpLAWv4BW8Dx6WwYIFC5aC11kK3mUpeB88LFiwDBYsBa/gFbzLMliwYMGCBct0loJX8AreBw/LYMGCBUvB6ywF77IUvA8eFixYBgsWrNfWAZPltufdad+1AAAAAElFTkSuQmCC",
|
||||
"CLEAR_BACKGROUND_COLOR": "#212732",
|
||||
"NODE_TITLE_COLOR": "#999",
|
||||
"NODE_SELECTED_TITLE_COLOR": "#e5eaf0",
|
||||
"NODE_TEXT_SIZE": 14,
|
||||
"NODE_TEXT_COLOR": "#bcc2c8",
|
||||
"NODE_SUBTEXT_SIZE": 12,
|
||||
"NODE_DEFAULT_COLOR": "#2e3440",
|
||||
"NODE_DEFAULT_BGCOLOR": "#161b22",
|
||||
"NODE_DEFAULT_BOXCOLOR": "#545d70",
|
||||
"NODE_DEFAULT_SHAPE": "box",
|
||||
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0",
|
||||
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
|
||||
"DEFAULT_GROUP_FONT": 24,
|
||||
"WIDGET_BGCOLOR": "#2e3440",
|
||||
"WIDGET_OUTLINE_COLOR": "#545d70",
|
||||
"WIDGET_TEXT_COLOR": "#bcc2c8",
|
||||
"WIDGET_SECONDARY_TEXT_COLOR": "#999",
|
||||
"LINK_COLOR": "#9A9",
|
||||
"EVENT_LINK_COLOR": "#A86",
|
||||
"CONNECTING_LINK_COLOR": "#AFA"
|
||||
},
|
||||
"comfy_base": {
|
||||
"fg-color": "#e5eaf0",
|
||||
"bg-color": "#2e3440",
|
||||
"comfy-menu-bg": "#161b22",
|
||||
"comfy-input-bg": "#2e3440",
|
||||
"input-text": "#bcc2c8",
|
||||
"descrip-text": "#999",
|
||||
"drag-text": "#ccc",
|
||||
"error-text": "#ff4444",
|
||||
"border-color": "#545d70",
|
||||
"tr-even-bg-color": "#2e3440",
|
||||
"tr-odd-bg-color": "#161b22"
|
||||
}
|
||||
},
|
||||
},
|
||||
"github": {
|
||||
"id": "github",
|
||||
"name": "Github",
|
||||
"colors": {
|
||||
"node_slot": {
|
||||
"BOOLEAN": "",
|
||||
"CLIP": "#eacb8b",
|
||||
"CLIP_VISION": "#A8DADC",
|
||||
"CLIP_VISION_OUTPUT": "#ad7452",
|
||||
"CONDITIONING": "#cf876f",
|
||||
"CONTROL_NET": "#00d78d",
|
||||
"CONTROL_NET_WEIGHTS": "",
|
||||
"FLOAT": "",
|
||||
"GLIGEN": "",
|
||||
"IMAGE": "#80a1c0",
|
||||
"IMAGEUPLOAD": "",
|
||||
"INT": "",
|
||||
"LATENT": "#b38ead",
|
||||
"LATENT_KEYFRAME": "",
|
||||
"MASK": "#a3bd8d",
|
||||
"MODEL": "#8978a7",
|
||||
"SAMPLER": "",
|
||||
"SIGMAS": "",
|
||||
"STRING": "",
|
||||
"STYLE_MODEL": "#C2FFAE",
|
||||
"T2I_ADAPTER_WEIGHTS": "",
|
||||
"TAESD": "#DCC274",
|
||||
"TIMESTEP_KEYFRAME": "",
|
||||
"UPSCALE_MODEL": "",
|
||||
"VAE": "#be616b"
|
||||
},
|
||||
"litegraph_base": {
|
||||
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGlmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4gPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iQWRvYmUgWE1QIENvcmUgOS4xLWMwMDEgNzkuMTQ2Mjg5OSwgMjAyMy8wNi8yNS0yMDowMTo1NSAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0RXZ0PSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VFdmVudCMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiB4bXA6Q3JlYXRlRGF0ZT0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgeG1wOk1vZGlmeURhdGU9IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIHhtcDpNZXRhZGF0YURhdGU9IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIGRjOmZvcm1hdD0iaW1hZ2UvcG5nIiBwaG90b3Nob3A6Q29sb3JNb2RlPSIzIiB4bXBNTTpJbnN0YW5jZUlEPSJ4bXAuaWlkOmIyYzRhNjA5LWJmYTctYTg0MC1iOGFlLTk3MzE2ZjM1ZGIyNyIgeG1wTU06RG9jdW1lbnRJRD0iYWRvYmU6ZG9jaWQ6cGhvdG9zaG9wOjk0ZmNlZGU4LTE1MTctZmQ0MC04ZGU3LWYzOTgxM2E3ODk5ZiIgeG1wTU06T3JpZ2luYWxEb2N1bWVudElEPSJ4bXAuZGlkOjIzMWIxMGIwLWI0ZmItMDI0ZS1iMTJlLTMwNTMwM2NkMDdjOCI+IDx4bXBNTTpIaXN0b3J5PiA8cmRmOlNlcT4gPHJkZjpsaSBzdEV2dDphY3Rpb249ImNyZWF0ZWQiIHN0RXZ0Omluc3RhbmNlSUQ9InhtcC5paWQ6MjMxYjEwYjAtYjRmYi0wMjRlLWIxMmUtMzA1MzAzY2QwN2M4IiBzdEV2dDp3aGVuPSIyMDIzLTExLTEzVDAwOjE4OjAyKzAxOjAwIiBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZG9iZSBQaG90b3Nob3AgMjUuMSAoV2luZG93cykiLz4gPHJkZjpsaSBzdEV2dDphY3Rpb249InNhdmVkIiBzdEV2dDppbnN0YW5jZUlEPSJ4bXAuaWlkOjQ4OWY1NzlmLTJkNjUtZWQ0Zi04OTg0LTA4NGE2MGE1ZTMzNSIgc3RFdnQ6d2hlbj0iMjAyMy0xMS0xNVQwMjowNDo1OSswMTowMCIgc3RFdnQ6c29mdHdhcmVBZ2VudD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiBzdEV2dDpjaGFuZ2VkPSIvIi8+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJzYXZlZCIgc3RFdnQ6aW5zdGFuY2VJRD0ieG1wLmlpZDpiMmM0YTYwOS1iZmE3LWE4NDAtYjhhZS05NzMxNmYzNWRiMjciIHN0RXZ0OndoZW49IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIHN0RXZ0OnNvZnR3YXJlQWdlbnQ9IkFkb2JlIFBob3Rvc2hvcCAyNS4xIChXaW5kb3dzKSIgc3RFdnQ6Y2hhbmdlZD0iLyIvPiA8L3JkZjpTZXE+IDwveG1wTU06SGlzdG9yeT4gPC9yZGY6RGVzY3JpcHRpb24+IDwvcmRmOlJERj4gPC94OnhtcG1ldGE+IDw/eHBhY2tldCBlbmQ9InIiPz4OTe6GAAAAx0lEQVR42u3WMQoAIQxFwRzJys77X8vSLiRgITif7bYbgrwYc/mKXyBoY4VVBgsWLFiwYFmOlTv+9jfDOjHmr8u6eVkGCxYsWLBgmc5S8ApewXvgYRksWLBgKXidpeBdloL3wMOCBctgwVLwCl7BuyyDBQsWLFiwTGcpeAWv4D3wsAwWLFiwFLzOUvAuS8F74GHBgmWwYCl4Ba/gXZbBggULFixYprMUvIJX8B54WAYLFixYCl5nKXiXpeA98LBgwTJYsGC9tg1o8f4TTtqzNQAAAABJRU5ErkJggg==",
|
||||
"CLEAR_BACKGROUND_COLOR": "#040506",
|
||||
"NODE_TITLE_COLOR": "#999",
|
||||
"NODE_SELECTED_TITLE_COLOR": "#e5eaf0",
|
||||
"NODE_TEXT_SIZE": 14,
|
||||
"NODE_TEXT_COLOR": "#bcc2c8",
|
||||
"NODE_SUBTEXT_SIZE": 12,
|
||||
"NODE_DEFAULT_COLOR": "#161b22",
|
||||
"NODE_DEFAULT_BGCOLOR": "#13171d",
|
||||
"NODE_DEFAULT_BOXCOLOR": "#30363d",
|
||||
"NODE_DEFAULT_SHAPE": "box",
|
||||
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0",
|
||||
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
|
||||
"DEFAULT_GROUP_FONT": 24,
|
||||
"WIDGET_BGCOLOR": "#161b22",
|
||||
"WIDGET_OUTLINE_COLOR": "#30363d",
|
||||
"WIDGET_TEXT_COLOR": "#bcc2c8",
|
||||
"WIDGET_SECONDARY_TEXT_COLOR": "#999",
|
||||
"LINK_COLOR": "#9A9",
|
||||
"EVENT_LINK_COLOR": "#A86",
|
||||
"CONNECTING_LINK_COLOR": "#AFA"
|
||||
},
|
||||
"comfy_base": {
|
||||
"fg-color": "#e5eaf0",
|
||||
"bg-color": "#161b22",
|
||||
"comfy-menu-bg": "#13171d",
|
||||
"comfy-input-bg": "#161b22",
|
||||
"input-text": "#bcc2c8",
|
||||
"descrip-text": "#999",
|
||||
"drag-text": "#ccc",
|
||||
"error-text": "#ff4444",
|
||||
"border-color": "#30363d",
|
||||
"tr-even-bg-color": "#161b22",
|
||||
"tr-odd-bg-color": "#13171d"
|
||||
}
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ const ext = {
|
||||
requestAnimationFrame(() => {
|
||||
const currentNode = LGraphCanvas.active_canvas.current_node;
|
||||
const clickedComboValue = currentNode.widgets
|
||||
.filter(w => w.type === "combo" && w.options.values.length === values.length)
|
||||
?.filter(w => w.type === "combo" && w.options.values.length === values.length)
|
||||
.find(w => w.options.values.every((v, i) => v === values[i]))
|
||||
?.value;
|
||||
|
||||
|
||||
@ -14,6 +14,9 @@ import { ComfyDialog, $el } from "../../scripts/ui.js";
|
||||
// To delete/rename:
|
||||
// Right click the canvas
|
||||
// Node templates -> Manage
|
||||
//
|
||||
// To rearrange:
|
||||
// Open the manage dialog and Drag and drop elements using the "Name:" label as handle
|
||||
|
||||
const id = "Comfy.NodeTemplates";
|
||||
|
||||
@ -22,6 +25,10 @@ class ManageTemplates extends ComfyDialog {
|
||||
super();
|
||||
this.element.classList.add("comfy-manage-templates");
|
||||
this.templates = this.load();
|
||||
this.draggedEl = null;
|
||||
this.saveVisualCue = null;
|
||||
this.emptyImg = new Image();
|
||||
this.emptyImg.src = 'data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs=';
|
||||
|
||||
this.importInput = $el("input", {
|
||||
type: "file",
|
||||
@ -35,14 +42,11 @@ class ManageTemplates extends ComfyDialog {
|
||||
|
||||
createButtons() {
|
||||
const btns = super.createButtons();
|
||||
btns[0].textContent = "Cancel";
|
||||
btns.unshift(
|
||||
$el("button", {
|
||||
type: "button",
|
||||
textContent: "Save",
|
||||
onclick: () => this.save(),
|
||||
})
|
||||
);
|
||||
btns[0].textContent = "Close";
|
||||
btns[0].onclick = (e) => {
|
||||
clearTimeout(this.saveVisualCue);
|
||||
this.close();
|
||||
};
|
||||
btns.unshift(
|
||||
$el("button", {
|
||||
type: "button",
|
||||
@ -71,25 +75,6 @@ class ManageTemplates extends ComfyDialog {
|
||||
}
|
||||
}
|
||||
|
||||
save() {
|
||||
// Find all visible inputs and save them as our new list
|
||||
const inputs = this.element.querySelectorAll("input");
|
||||
const updated = [];
|
||||
|
||||
for (let i = 0; i < inputs.length; i++) {
|
||||
const input = inputs[i];
|
||||
if (input.parentElement.style.display !== "none") {
|
||||
const t = this.templates[i];
|
||||
t.name = input.value.trim() || input.getAttribute("data-name");
|
||||
updated.push(t);
|
||||
}
|
||||
}
|
||||
|
||||
this.templates = updated;
|
||||
this.store();
|
||||
this.close();
|
||||
}
|
||||
|
||||
store() {
|
||||
localStorage.setItem(id, JSON.stringify(this.templates));
|
||||
}
|
||||
@ -145,71 +130,155 @@ class ManageTemplates extends ComfyDialog {
|
||||
super.show(
|
||||
$el(
|
||||
"div",
|
||||
{
|
||||
style: {
|
||||
display: "grid",
|
||||
gridTemplateColumns: "1fr auto",
|
||||
gap: "5px",
|
||||
},
|
||||
},
|
||||
this.templates.flatMap((t) => {
|
||||
{},
|
||||
this.templates.flatMap((t,i) => {
|
||||
let nameInput;
|
||||
return [
|
||||
$el(
|
||||
"label",
|
||||
"div",
|
||||
{
|
||||
textContent: "Name: ",
|
||||
dataset: { id: i },
|
||||
className: "tempateManagerRow",
|
||||
style: {
|
||||
display: "grid",
|
||||
gridTemplateColumns: "1fr auto",
|
||||
border: "1px dashed transparent",
|
||||
gap: "5px",
|
||||
backgroundColor: "var(--comfy-menu-bg)"
|
||||
},
|
||||
ondragstart: (e) => {
|
||||
this.draggedEl = e.currentTarget;
|
||||
e.currentTarget.style.opacity = "0.6";
|
||||
e.currentTarget.style.border = "1px dashed yellow";
|
||||
e.dataTransfer.effectAllowed = 'move';
|
||||
e.dataTransfer.setDragImage(this.emptyImg, 0, 0);
|
||||
},
|
||||
ondragend: (e) => {
|
||||
e.target.style.opacity = "1";
|
||||
e.currentTarget.style.border = "1px dashed transparent";
|
||||
e.currentTarget.removeAttribute("draggable");
|
||||
|
||||
// rearrange the elements in the localStorage
|
||||
this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => {
|
||||
var prev_i = el.dataset.id;
|
||||
|
||||
if ( el == this.draggedEl && prev_i != i ) {
|
||||
[this.templates[i], this.templates[prev_i]] = [this.templates[prev_i], this.templates[i]];
|
||||
}
|
||||
el.dataset.id = i;
|
||||
});
|
||||
this.store();
|
||||
},
|
||||
ondragover: (e) => {
|
||||
e.preventDefault();
|
||||
if ( e.currentTarget == this.draggedEl )
|
||||
return;
|
||||
|
||||
let rect = e.currentTarget.getBoundingClientRect();
|
||||
if (e.clientY > rect.top + rect.height / 2) {
|
||||
e.currentTarget.parentNode.insertBefore(this.draggedEl, e.currentTarget.nextSibling);
|
||||
} else {
|
||||
e.currentTarget.parentNode.insertBefore(this.draggedEl, e.currentTarget);
|
||||
}
|
||||
}
|
||||
},
|
||||
[
|
||||
$el("input", {
|
||||
value: t.name,
|
||||
dataset: { name: t.name },
|
||||
$: (el) => (nameInput = el),
|
||||
}),
|
||||
$el(
|
||||
"label",
|
||||
{
|
||||
textContent: "Name: ",
|
||||
style: {
|
||||
cursor: "grab",
|
||||
},
|
||||
onmousedown: (e) => {
|
||||
// enable dragging only from the label
|
||||
if (e.target.localName == 'label')
|
||||
e.currentTarget.parentNode.draggable = 'true';
|
||||
}
|
||||
},
|
||||
[
|
||||
$el("input", {
|
||||
value: t.name,
|
||||
dataset: { name: t.name },
|
||||
style: {
|
||||
transitionProperty: 'background-color',
|
||||
transitionDuration: '0s',
|
||||
},
|
||||
onchange: (e) => {
|
||||
clearTimeout(this.saveVisualCue);
|
||||
var el = e.target;
|
||||
var row = el.parentNode.parentNode;
|
||||
this.templates[row.dataset.id].name = el.value.trim() || 'untitled';
|
||||
this.store();
|
||||
el.style.backgroundColor = 'rgb(40, 95, 40)';
|
||||
el.style.transitionDuration = '0s';
|
||||
this.saveVisualCue = setTimeout(function () {
|
||||
el.style.transitionDuration = '.7s';
|
||||
el.style.backgroundColor = 'var(--comfy-input-bg)';
|
||||
}, 15);
|
||||
},
|
||||
onkeypress: (e) => {
|
||||
var el = e.target;
|
||||
clearTimeout(this.saveVisualCue);
|
||||
el.style.transitionDuration = '0s';
|
||||
el.style.backgroundColor = 'var(--comfy-input-bg)';
|
||||
},
|
||||
$: (el) => (nameInput = el),
|
||||
})
|
||||
]
|
||||
),
|
||||
$el(
|
||||
"div",
|
||||
{},
|
||||
[
|
||||
$el("button", {
|
||||
textContent: "Export",
|
||||
style: {
|
||||
fontSize: "12px",
|
||||
fontWeight: "normal",
|
||||
},
|
||||
onclick: (e) => {
|
||||
const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string
|
||||
const blob = new Blob([json], {type: "application/json"});
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = $el("a", {
|
||||
href: url,
|
||||
download: (nameInput.value || t.name) + ".json",
|
||||
style: {display: "none"},
|
||||
parent: document.body,
|
||||
});
|
||||
a.click();
|
||||
setTimeout(function () {
|
||||
a.remove();
|
||||
window.URL.revokeObjectURL(url);
|
||||
}, 0);
|
||||
},
|
||||
}),
|
||||
$el("button", {
|
||||
textContent: "Delete",
|
||||
style: {
|
||||
fontSize: "12px",
|
||||
color: "red",
|
||||
fontWeight: "normal",
|
||||
},
|
||||
onclick: (e) => {
|
||||
const item = e.target.parentNode.parentNode;
|
||||
item.parentNode.removeChild(item);
|
||||
this.templates.splice(item.dataset.id*1, 1);
|
||||
this.store();
|
||||
// update the rows index, setTimeout ensures that the list is updated
|
||||
var that = this;
|
||||
setTimeout(function (){
|
||||
that.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => {
|
||||
el.dataset.id = i;
|
||||
});
|
||||
}, 0);
|
||||
},
|
||||
}),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
$el(
|
||||
"div",
|
||||
{},
|
||||
[
|
||||
$el("button", {
|
||||
textContent: "Export",
|
||||
style: {
|
||||
fontSize: "12px",
|
||||
fontWeight: "normal",
|
||||
},
|
||||
onclick: (e) => {
|
||||
const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string
|
||||
const blob = new Blob([json], {type: "application/json"});
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = $el("a", {
|
||||
href: url,
|
||||
download: (nameInput.value || t.name) + ".json",
|
||||
style: {display: "none"},
|
||||
parent: document.body,
|
||||
});
|
||||
a.click();
|
||||
setTimeout(function () {
|
||||
a.remove();
|
||||
window.URL.revokeObjectURL(url);
|
||||
}, 0);
|
||||
},
|
||||
}),
|
||||
$el("button", {
|
||||
textContent: "Delete",
|
||||
style: {
|
||||
fontSize: "12px",
|
||||
color: "red",
|
||||
fontWeight: "normal",
|
||||
},
|
||||
onclick: (e) => {
|
||||
nameInput.value = "";
|
||||
e.target.parentElement.style.display = "none";
|
||||
e.target.parentElement.previousElementSibling.style.display = "none";
|
||||
},
|
||||
}),
|
||||
]
|
||||
),
|
||||
)
|
||||
];
|
||||
})
|
||||
)
|
||||
|
||||
150
web/extensions/core/undoRedo.js
Normal file
150
web/extensions/core/undoRedo.js
Normal file
@ -0,0 +1,150 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
|
||||
const MAX_HISTORY = 50;
|
||||
|
||||
let undo = [];
|
||||
let redo = [];
|
||||
let activeState = null;
|
||||
let isOurLoad = false;
|
||||
function checkState() {
|
||||
const currentState = app.graph.serialize();
|
||||
if (!graphEqual(activeState, currentState)) {
|
||||
undo.push(activeState);
|
||||
if (undo.length > MAX_HISTORY) {
|
||||
undo.shift();
|
||||
}
|
||||
activeState = clone(currentState);
|
||||
redo.length = 0;
|
||||
}
|
||||
}
|
||||
|
||||
const loadGraphData = app.loadGraphData;
|
||||
app.loadGraphData = async function () {
|
||||
const v = await loadGraphData.apply(this, arguments);
|
||||
if (isOurLoad) {
|
||||
isOurLoad = false;
|
||||
} else {
|
||||
checkState();
|
||||
}
|
||||
return v;
|
||||
};
|
||||
|
||||
function clone(obj) {
|
||||
try {
|
||||
if (typeof structuredClone !== "undefined") {
|
||||
return structuredClone(obj);
|
||||
}
|
||||
} catch (error) {
|
||||
// structuredClone is stricter than using JSON.parse/stringify so fallback to that
|
||||
}
|
||||
|
||||
return JSON.parse(JSON.stringify(obj));
|
||||
}
|
||||
|
||||
function graphEqual(a, b, root = true) {
|
||||
if (a === b) return true;
|
||||
|
||||
if (typeof a == "object" && a && typeof b == "object" && b) {
|
||||
const keys = Object.getOwnPropertyNames(a);
|
||||
|
||||
if (keys.length != Object.getOwnPropertyNames(b).length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const key of keys) {
|
||||
let av = a[key];
|
||||
let bv = b[key];
|
||||
if (root && key === "nodes") {
|
||||
// Nodes need to be sorted as the order changes when selecting nodes
|
||||
av = [...av].sort((a, b) => a.id - b.id);
|
||||
bv = [...bv].sort((a, b) => a.id - b.id);
|
||||
}
|
||||
if (!graphEqual(av, bv, false)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
const undoRedo = async (e) => {
|
||||
if (e.ctrlKey || e.metaKey) {
|
||||
if (e.key === "y") {
|
||||
const prevState = redo.pop();
|
||||
if (prevState) {
|
||||
undo.push(activeState);
|
||||
isOurLoad = true;
|
||||
await app.loadGraphData(prevState);
|
||||
activeState = prevState;
|
||||
}
|
||||
return true;
|
||||
} else if (e.key === "z") {
|
||||
const prevState = undo.pop();
|
||||
if (prevState) {
|
||||
redo.push(activeState);
|
||||
isOurLoad = true;
|
||||
await app.loadGraphData(prevState);
|
||||
activeState = prevState;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const bindInput = (activeEl) => {
|
||||
if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") {
|
||||
for (const evt of ["change", "input", "blur"]) {
|
||||
if (`on${evt}` in activeEl) {
|
||||
const listener = () => {
|
||||
checkState();
|
||||
activeEl.removeEventListener(evt, listener);
|
||||
};
|
||||
activeEl.addEventListener(evt, listener);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener(
|
||||
"keydown",
|
||||
(e) => {
|
||||
requestAnimationFrame(async () => {
|
||||
const activeEl = document.activeElement;
|
||||
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") {
|
||||
// Ignore events on inputs, they have their native history
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if this is a ctrl+z ctrl+y
|
||||
if (await undoRedo(e)) return;
|
||||
|
||||
// If our active element is some type of input then handle changes after they're done
|
||||
if (bindInput(activeEl)) return;
|
||||
checkState();
|
||||
});
|
||||
},
|
||||
true
|
||||
);
|
||||
|
||||
// Handle clicking DOM elements (e.g. widgets)
|
||||
window.addEventListener("mouseup", () => {
|
||||
checkState();
|
||||
});
|
||||
|
||||
// Handle litegraph clicks
|
||||
const processMouseUp = LGraphCanvas.prototype.processMouseUp;
|
||||
LGraphCanvas.prototype.processMouseUp = function (e) {
|
||||
const v = processMouseUp.apply(this, arguments);
|
||||
checkState();
|
||||
return v;
|
||||
};
|
||||
const processMouseDown = LGraphCanvas.prototype.processMouseDown;
|
||||
LGraphCanvas.prototype.processMouseDown = function (e) {
|
||||
const v = processMouseDown.apply(this, arguments);
|
||||
checkState();
|
||||
return v;
|
||||
};
|
||||
@ -1,4 +1,4 @@
|
||||
import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js";
|
||||
import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
|
||||
import { app } from "../../scripts/app.js";
|
||||
|
||||
const CONVERTED_TYPE = "converted-widget";
|
||||
@ -467,7 +467,11 @@ app.registerExtension({
|
||||
if (!control_value) {
|
||||
control_value = "fixed";
|
||||
}
|
||||
addValueControlWidget(this, widget, control_value);
|
||||
addValueControlWidgets(this, widget, control_value);
|
||||
let filter = this.widgets_values?.[2];
|
||||
if(filter && this.widgets.length === 3) {
|
||||
this.widgets[2].value = filter;
|
||||
}
|
||||
}
|
||||
|
||||
// When our value changes, update other widgets to reflect our changes
|
||||
|
||||
@ -2533,7 +2533,7 @@
|
||||
var w = this.widgets[i];
|
||||
if(!w)
|
||||
continue;
|
||||
if(w.options && w.options.property && this.properties[ w.options.property ])
|
||||
if(w.options && w.options.property && (this.properties[ w.options.property ] != undefined))
|
||||
w.value = JSON.parse( JSON.stringify( this.properties[ w.options.property ] ) );
|
||||
}
|
||||
if (info.widgets_values) {
|
||||
@ -5717,10 +5717,10 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
* @method enableWebGL
|
||||
**/
|
||||
LGraphCanvas.prototype.enableWebGL = function() {
|
||||
if (typeof GL === undefined) {
|
||||
if (typeof GL === "undefined") {
|
||||
throw "litegl.js must be included to use a WebGL canvas";
|
||||
}
|
||||
if (typeof enableWebGLCanvas === undefined) {
|
||||
if (typeof enableWebGLCanvas === "undefined") {
|
||||
throw "webglCanvas.js must be included to use this feature";
|
||||
}
|
||||
|
||||
@ -7113,15 +7113,16 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
}
|
||||
};
|
||||
|
||||
LGraphCanvas.prototype.copyToClipboard = function() {
|
||||
LGraphCanvas.prototype.copyToClipboard = function(nodes) {
|
||||
var clipboard_info = {
|
||||
nodes: [],
|
||||
links: []
|
||||
};
|
||||
var index = 0;
|
||||
var selected_nodes_array = [];
|
||||
for (var i in this.selected_nodes) {
|
||||
var node = this.selected_nodes[i];
|
||||
if (!nodes) nodes = this.selected_nodes;
|
||||
for (var i in nodes) {
|
||||
var node = nodes[i];
|
||||
if (node.clonable === false)
|
||||
continue;
|
||||
node._relative_id = index;
|
||||
@ -11705,7 +11706,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
default:
|
||||
iS = 0; // try with first if no name set
|
||||
}
|
||||
if (typeof options.node_from.outputs[iS] !== undefined){
|
||||
if (typeof options.node_from.outputs[iS] !== "undefined"){
|
||||
if (iS!==false && iS>-1){
|
||||
options.node_from.connectByType( iS, node, options.node_from.outputs[iS].type );
|
||||
}
|
||||
@ -11733,7 +11734,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
default:
|
||||
iS = 0; // try with first if no name set
|
||||
}
|
||||
if (typeof options.node_to.inputs[iS] !== undefined){
|
||||
if (typeof options.node_to.inputs[iS] !== "undefined"){
|
||||
if (iS!==false && iS>-1){
|
||||
// try connection
|
||||
options.node_to.connectByTypeOutput(iS,node,options.node_to.inputs[iS].type);
|
||||
|
||||
@ -254,9 +254,9 @@ class ComfyApi extends EventTarget {
|
||||
* Gets the prompt execution history
|
||||
* @returns Prompt history including node outputs
|
||||
*/
|
||||
async getHistory() {
|
||||
async getHistory(max_items=200) {
|
||||
try {
|
||||
const res = await this.fetchApi("/history");
|
||||
const res = await this.fetchApi(`/history?max_items=${max_items}`);
|
||||
return { History: Object.values(await res.json()) };
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
|
||||
@ -3,7 +3,26 @@ import { ComfyWidgets } from "./widgets.js";
|
||||
import { ComfyUI, $el } from "./ui.js";
|
||||
import { api } from "./api.js";
|
||||
import { defaultGraph } from "./defaultGraph.js";
|
||||
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
|
||||
import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
|
||||
import { addDomClippingSetting } from "./domWidget.js";
|
||||
import { createImageHost, calculateImageGrid } from "./ui/imagePreview.js"
|
||||
|
||||
export const ANIM_PREVIEW_WIDGET = "$$comfy_animation_preview"
|
||||
|
||||
function sanitizeNodeName(string) {
|
||||
let entityMap = {
|
||||
'&': '',
|
||||
'<': '',
|
||||
'>': '',
|
||||
'"': '',
|
||||
"'": '',
|
||||
'`': '',
|
||||
'=': ''
|
||||
};
|
||||
return String(string).replace(/[&<>"'`=]/g, function fromEntityMap (s) {
|
||||
return entityMap[s];
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
|
||||
@ -389,7 +408,9 @@ export class ComfyApp {
|
||||
return shiftY;
|
||||
}
|
||||
|
||||
node.prototype.setSizeForImage = function () {
|
||||
node.prototype.setSizeForImage = function (force) {
|
||||
if(!force && this.animatedImages) return;
|
||||
|
||||
if (this.inputHeight) {
|
||||
this.setSize(this.size);
|
||||
return;
|
||||
@ -406,13 +427,20 @@ export class ComfyApp {
|
||||
let imagesChanged = false
|
||||
|
||||
const output = app.nodeOutputs[this.id + ""];
|
||||
if (output && output.images) {
|
||||
if (output?.images) {
|
||||
this.animatedImages = output?.animated?.find(Boolean);
|
||||
if (this.images !== output.images) {
|
||||
this.images = output.images;
|
||||
imagesChanged = true;
|
||||
imgURLs = imgURLs.concat(output.images.map(params => {
|
||||
return api.apiURL("/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam());
|
||||
}))
|
||||
imgURLs = imgURLs.concat(
|
||||
output.images.map((params) => {
|
||||
return api.apiURL(
|
||||
"/view?" +
|
||||
new URLSearchParams(params).toString() +
|
||||
(this.animatedImages ? "" : app.getPreviewFormatParam())
|
||||
);
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -491,7 +519,35 @@ export class ComfyApp {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (this.imgs && this.imgs.length) {
|
||||
if (this.imgs?.length) {
|
||||
const widgetIdx = this.widgets?.findIndex((w) => w.name === ANIM_PREVIEW_WIDGET);
|
||||
|
||||
if(this.animatedImages) {
|
||||
// Instead of using the canvas we'll use a IMG
|
||||
if(widgetIdx > -1) {
|
||||
// Replace content
|
||||
const widget = this.widgets[widgetIdx];
|
||||
widget.options.host.updateImages(this.imgs);
|
||||
} else {
|
||||
const host = createImageHost(this);
|
||||
this.setSizeForImage(true);
|
||||
const widget = this.addDOMWidget(ANIM_PREVIEW_WIDGET, "img", host.el, {
|
||||
host,
|
||||
getHeight: host.getHeight,
|
||||
onDraw: host.onDraw,
|
||||
hideOnZoom: false
|
||||
});
|
||||
widget.serializeValue = () => undefined;
|
||||
widget.options.host.updateImages(this.imgs);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (widgetIdx > -1) {
|
||||
this.widgets[widgetIdx].onRemove?.();
|
||||
this.widgets.splice(widgetIdx, 1);
|
||||
}
|
||||
|
||||
const canvas = app.graph.list_of_graphcanvas[0];
|
||||
const mouse = canvas.graph_mouse;
|
||||
if (!canvas.pointer_is_down && this.pointerDown) {
|
||||
@ -531,31 +587,7 @@ export class ComfyApp {
|
||||
}
|
||||
else {
|
||||
cell_padding = 0;
|
||||
let best = 0;
|
||||
let w = this.imgs[0].naturalWidth;
|
||||
let h = this.imgs[0].naturalHeight;
|
||||
|
||||
// compact style
|
||||
for (let c = 1; c <= numImages; c++) {
|
||||
const rows = Math.ceil(numImages / c);
|
||||
const cW = dw / c;
|
||||
const cH = dh / rows;
|
||||
const scaleX = cW / w;
|
||||
const scaleY = cH / h;
|
||||
|
||||
const scale = Math.min(scaleX, scaleY, 1);
|
||||
const imageW = w * scale;
|
||||
const imageH = h * scale;
|
||||
const area = imageW * imageH * numImages;
|
||||
|
||||
if (area > best) {
|
||||
best = area;
|
||||
cellWidth = imageW;
|
||||
cellHeight = imageH;
|
||||
cols = c;
|
||||
shiftX = c * ((cW - imageW) / 2);
|
||||
}
|
||||
}
|
||||
({ cellWidth, cellHeight, cols, shiftX } = calculateImageGrid(this.imgs, dw, dh));
|
||||
}
|
||||
|
||||
let anyHovered = false;
|
||||
@ -1256,6 +1288,7 @@ export class ComfyApp {
|
||||
canvasEl.tabIndex = "1";
|
||||
document.body.prepend(canvasEl);
|
||||
|
||||
addDomClippingSetting();
|
||||
this.#addProcessMouseHandler();
|
||||
this.#addProcessKeyHandler();
|
||||
this.#addConfigureHandler();
|
||||
@ -1453,6 +1486,17 @@ export class ComfyApp {
|
||||
localStorage.setItem("litegrapheditor_clipboard", old);
|
||||
}
|
||||
|
||||
showMissingNodesError(missingNodeTypes, hasAddedNodes = true) {
|
||||
this.ui.dialog.show(
|
||||
`When loading the graph, the following node types were not found: <ul>${Array.from(new Set(missingNodeTypes)).map(
|
||||
(t) => `<li>${t}</li>`
|
||||
).join("")}</ul>${hasAddedNodes ? "Nodes that have failed to load will show as red on the graph." : ""}`
|
||||
);
|
||||
this.logging.addEntry("Comfy.App", "warn", {
|
||||
MissingNodes: missingNodeTypes,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Populates the graph with the specified workflow data
|
||||
* @param {*} graphData A serialized graph object
|
||||
@ -1462,24 +1506,28 @@ export class ComfyApp {
|
||||
|
||||
let reset_invalid_values = false;
|
||||
if (!graphData) {
|
||||
if (typeof structuredClone === "undefined")
|
||||
{
|
||||
graphData = JSON.parse(JSON.stringify(defaultGraph));
|
||||
}else
|
||||
{
|
||||
graphData = structuredClone(defaultGraph);
|
||||
}
|
||||
graphData = defaultGraph;
|
||||
reset_invalid_values = true;
|
||||
}
|
||||
|
||||
if (typeof structuredClone === "undefined")
|
||||
{
|
||||
graphData = JSON.parse(JSON.stringify(graphData));
|
||||
}else
|
||||
{
|
||||
graphData = structuredClone(graphData);
|
||||
}
|
||||
|
||||
const missingNodeTypes = [];
|
||||
for (let n of graphData.nodes) {
|
||||
// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
|
||||
if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader";
|
||||
if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix
|
||||
if (n.type == "SDV_img2vid_Conditioning") n.type = "SVD_img2vid_Conditioning"; //typo fix
|
||||
|
||||
// Find missing node types
|
||||
if (!(n.type in LiteGraph.registered_node_types)) {
|
||||
n.type = sanitizeNodeName(n.type);
|
||||
missingNodeTypes.push(n.type);
|
||||
}
|
||||
}
|
||||
@ -1570,14 +1618,7 @@ export class ComfyApp {
|
||||
}
|
||||
|
||||
if (missingNodeTypes.length) {
|
||||
this.ui.dialog.show(
|
||||
`When loading the graph, the following node types were not found: <ul>${Array.from(new Set(missingNodeTypes)).map(
|
||||
(t) => `<li>${t}</li>`
|
||||
).join("")}</ul>Nodes that have failed to load will show as red on the graph.`
|
||||
);
|
||||
this.logging.addEntry("Comfy.App", "warn", {
|
||||
MissingNodes: missingNodeTypes,
|
||||
});
|
||||
this.showMissingNodesError(missingNodeTypes);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1589,7 +1630,17 @@ export class ComfyApp {
|
||||
* @returns The workflow and node links
|
||||
*/
|
||||
async graphToPrompt(graph=this.graph, nodeIdOffset=0, path="/", globalMappings={}) {
|
||||
const workflow = graph.serialize();
|
||||
for (const node of this.graph.computeExecutionOrder(false)) {
|
||||
if (node.isVirtualNode) {
|
||||
// Don't serialize frontend only nodes but let them make changes
|
||||
if (node.applyToGraph) {
|
||||
node.applyToGraph();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const workflow = this.graph.serialize();
|
||||
const output = {};
|
||||
|
||||
let childNodeCount = 0;
|
||||
@ -1598,10 +1649,6 @@ export class ComfyApp {
|
||||
const n = workflow.nodes.find((n) => n.id === node.id);
|
||||
|
||||
if (node.isVirtualNode) {
|
||||
// Don't serialize frontend only nodes but let them make changes
|
||||
if (node.applyToGraph) {
|
||||
node.applyToGraph(workflow);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -1862,12 +1909,23 @@ export class ComfyApp {
|
||||
importA1111(this.graph, pngInfo.parameters);
|
||||
}
|
||||
}
|
||||
} else if (file.type === "image/webp") {
|
||||
const pngInfo = await getWebpMetadata(file);
|
||||
if (pngInfo) {
|
||||
if (pngInfo.workflow) {
|
||||
this.loadGraphData(JSON.parse(pngInfo.workflow));
|
||||
} else if (pngInfo.Workflow) {
|
||||
this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node.
|
||||
}
|
||||
}
|
||||
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
|
||||
const reader = new FileReader();
|
||||
reader.onload = () => {
|
||||
var jsonContent = JSON.parse(reader.result);
|
||||
const jsonContent = JSON.parse(reader.result);
|
||||
if (jsonContent?.templates) {
|
||||
this.loadTemplateData(jsonContent);
|
||||
} else if(this.isApiJson(jsonContent)) {
|
||||
this.loadApiJson(jsonContent);
|
||||
} else {
|
||||
this.loadGraphData(jsonContent);
|
||||
}
|
||||
@ -1881,6 +1939,51 @@ export class ComfyApp {
|
||||
}
|
||||
}
|
||||
|
||||
isApiJson(data) {
|
||||
return Object.values(data).every((v) => v.class_type);
|
||||
}
|
||||
|
||||
loadApiJson(apiData) {
|
||||
const missingNodeTypes = Object.values(apiData).filter((n) => !LiteGraph.registered_node_types[n.class_type]);
|
||||
if (missingNodeTypes.length) {
|
||||
this.showMissingNodesError(missingNodeTypes.map(t => t.class_type), false);
|
||||
return;
|
||||
}
|
||||
|
||||
const ids = Object.keys(apiData);
|
||||
app.graph.clear();
|
||||
for (const id of ids) {
|
||||
const data = apiData[id];
|
||||
const node = LiteGraph.createNode(data.class_type);
|
||||
node.id = isNaN(+id) ? id : +id;
|
||||
graph.add(node);
|
||||
}
|
||||
|
||||
for (const id of ids) {
|
||||
const data = apiData[id];
|
||||
const node = app.graph.getNodeById(id);
|
||||
for (const input in data.inputs ?? {}) {
|
||||
const value = data.inputs[input];
|
||||
if (value instanceof Array) {
|
||||
const [fromId, fromSlot] = value;
|
||||
const fromNode = app.graph.getNodeById(fromId);
|
||||
const toSlot = node.inputs?.findIndex((inp) => inp.name === input);
|
||||
if (toSlot !== -1) {
|
||||
fromNode.connect(fromSlot, node, toSlot);
|
||||
}
|
||||
} else {
|
||||
const widget = node.widgets?.find((w) => w.name === input);
|
||||
if (widget) {
|
||||
widget.value = value;
|
||||
widget.callback?.(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
app.graph.arrange();
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers a Comfy web extension with the app
|
||||
* @param {ComfyExtension} extension
|
||||
|
||||
323
web/scripts/domWidget.js
Normal file
323
web/scripts/domWidget.js
Normal file
@ -0,0 +1,323 @@
|
||||
import { app, ANIM_PREVIEW_WIDGET } from "./app.js";
|
||||
|
||||
const SIZE = Symbol();
|
||||
|
||||
function intersect(a, b) {
|
||||
const x = Math.max(a.x, b.x);
|
||||
const num1 = Math.min(a.x + a.width, b.x + b.width);
|
||||
const y = Math.max(a.y, b.y);
|
||||
const num2 = Math.min(a.y + a.height, b.y + b.height);
|
||||
if (num1 >= x && num2 >= y) return [x, y, num1 - x, num2 - y];
|
||||
else return null;
|
||||
}
|
||||
|
||||
function getClipPath(node, element, elRect) {
|
||||
const selectedNode = Object.values(app.canvas.selected_nodes)[0];
|
||||
if (selectedNode && selectedNode !== node) {
|
||||
const MARGIN = 7;
|
||||
const scale = app.canvas.ds.scale;
|
||||
|
||||
const bounding = selectedNode.getBounding();
|
||||
const intersection = intersect(
|
||||
{ x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale },
|
||||
{
|
||||
x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN,
|
||||
y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN,
|
||||
width: bounding[2] + MARGIN + MARGIN,
|
||||
height: bounding[3] + MARGIN + MARGIN,
|
||||
}
|
||||
);
|
||||
|
||||
if (!intersection) {
|
||||
return "";
|
||||
}
|
||||
|
||||
const widgetRect = element.getBoundingClientRect();
|
||||
const clipX = intersection[0] - widgetRect.x / scale + "px";
|
||||
const clipY = intersection[1] - widgetRect.y / scale + "px";
|
||||
const clipWidth = intersection[2] + "px";
|
||||
const clipHeight = intersection[3] + "px";
|
||||
const path = `polygon(0% 0%, 0% 100%, ${clipX} 100%, ${clipX} ${clipY}, calc(${clipX} + ${clipWidth}) ${clipY}, calc(${clipX} + ${clipWidth}) calc(${clipY} + ${clipHeight}), ${clipX} calc(${clipY} + ${clipHeight}), ${clipX} 100%, 100% 100%, 100% 0%)`;
|
||||
return path;
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
function computeSize(size) {
|
||||
if (this.widgets?.[0].last_y == null) return;
|
||||
|
||||
let y = this.widgets[0].last_y;
|
||||
let freeSpace = size[1] - y;
|
||||
|
||||
let widgetHeight = 0;
|
||||
let dom = [];
|
||||
for (const w of this.widgets) {
|
||||
if (w.type === "converted-widget") {
|
||||
// Ignore
|
||||
delete w.computedHeight;
|
||||
} else if (w.computeSize) {
|
||||
widgetHeight += w.computeSize()[1] + 4;
|
||||
} else if (w.element) {
|
||||
// Extract DOM widget size info
|
||||
const styles = getComputedStyle(w.element);
|
||||
let minHeight = w.options.getMinHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-min-height"));
|
||||
let maxHeight = w.options.getMaxHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-max-height"));
|
||||
|
||||
let prefHeight = w.options.getHeight?.() ?? styles.getPropertyValue("--comfy-widget-height");
|
||||
if (prefHeight.endsWith?.("%")) {
|
||||
prefHeight = size[1] * (parseFloat(prefHeight.substring(0, prefHeight.length - 1)) / 100);
|
||||
} else {
|
||||
prefHeight = parseInt(prefHeight);
|
||||
if (isNaN(minHeight)) {
|
||||
minHeight = prefHeight;
|
||||
}
|
||||
}
|
||||
if (isNaN(minHeight)) {
|
||||
minHeight = 50;
|
||||
}
|
||||
if (!isNaN(maxHeight)) {
|
||||
if (!isNaN(prefHeight)) {
|
||||
prefHeight = Math.min(prefHeight, maxHeight);
|
||||
} else {
|
||||
prefHeight = maxHeight;
|
||||
}
|
||||
}
|
||||
dom.push({
|
||||
minHeight,
|
||||
prefHeight,
|
||||
w,
|
||||
});
|
||||
} else {
|
||||
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
|
||||
}
|
||||
}
|
||||
|
||||
freeSpace -= widgetHeight;
|
||||
|
||||
// Calculate sizes with all widgets at their min height
|
||||
const prefGrow = []; // Nodes that want to grow to their prefd size
|
||||
const canGrow = []; // Nodes that can grow to auto size
|
||||
let growBy = 0;
|
||||
for (const d of dom) {
|
||||
freeSpace -= d.minHeight;
|
||||
if (isNaN(d.prefHeight)) {
|
||||
canGrow.push(d);
|
||||
d.w.computedHeight = d.minHeight;
|
||||
} else {
|
||||
const diff = d.prefHeight - d.minHeight;
|
||||
if (diff > 0) {
|
||||
prefGrow.push(d);
|
||||
growBy += diff;
|
||||
d.diff = diff;
|
||||
} else {
|
||||
d.w.computedHeight = d.minHeight;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (this.imgs && !this.widgets.find((w) => w.name === ANIM_PREVIEW_WIDGET)) {
|
||||
// Allocate space for image
|
||||
freeSpace -= 220;
|
||||
}
|
||||
|
||||
if (freeSpace < 0) {
|
||||
// Not enough space for all widgets so we need to grow
|
||||
size[1] -= freeSpace;
|
||||
this.graph.setDirtyCanvas(true);
|
||||
} else {
|
||||
// Share the space between each
|
||||
const growDiff = freeSpace - growBy;
|
||||
if (growDiff > 0) {
|
||||
// All pref sizes can be fulfilled
|
||||
freeSpace = growDiff;
|
||||
for (const d of prefGrow) {
|
||||
d.w.computedHeight = d.prefHeight;
|
||||
}
|
||||
} else {
|
||||
// We need to grow evenly
|
||||
const shared = -growDiff / prefGrow.length;
|
||||
for (const d of prefGrow) {
|
||||
d.w.computedHeight = d.prefHeight - shared;
|
||||
}
|
||||
freeSpace = 0;
|
||||
}
|
||||
|
||||
if (freeSpace > 0 && canGrow.length) {
|
||||
// Grow any that are auto height
|
||||
const shared = freeSpace / canGrow.length;
|
||||
for (const d of canGrow) {
|
||||
d.w.computedHeight += shared;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Position each of the widgets
|
||||
for (const w of this.widgets) {
|
||||
w.y = y;
|
||||
if (w.computedHeight) {
|
||||
y += w.computedHeight;
|
||||
} else if (w.computeSize) {
|
||||
y += w.computeSize()[1] + 4;
|
||||
} else {
|
||||
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Override the compute visible nodes function to allow us to hide/show DOM elements when the node goes offscreen
|
||||
const elementWidgets = new Set();
|
||||
const computeVisibleNodes = LGraphCanvas.prototype.computeVisibleNodes;
|
||||
LGraphCanvas.prototype.computeVisibleNodes = function () {
|
||||
const visibleNodes = computeVisibleNodes.apply(this, arguments);
|
||||
for (const node of app.graph._nodes) {
|
||||
if (elementWidgets.has(node)) {
|
||||
const hidden = visibleNodes.indexOf(node) === -1;
|
||||
for (const w of node.widgets) {
|
||||
if (w.element) {
|
||||
w.element.hidden = hidden;
|
||||
if (hidden) {
|
||||
w.options.onHide?.(w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return visibleNodes;
|
||||
};
|
||||
|
||||
let enableDomClipping = true;
|
||||
|
||||
export function addDomClippingSetting() {
|
||||
app.ui.settings.addSetting({
|
||||
id: "Comfy.DOMClippingEnabled",
|
||||
name: "Enable DOM element clipping (enabling may reduce performance)",
|
||||
type: "boolean",
|
||||
defaultValue: enableDomClipping,
|
||||
onChange(value) {
|
||||
console.log("enableDomClipping", enableDomClipping);
|
||||
enableDomClipping = !!value;
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
LGraphNode.prototype.addDOMWidget = function (name, type, element, options) {
|
||||
options = { hideOnZoom: true, selectOn: ["focus", "click"], ...options };
|
||||
|
||||
if (!element.parentElement) {
|
||||
document.body.append(element);
|
||||
}
|
||||
|
||||
let mouseDownHandler;
|
||||
if (element.blur) {
|
||||
mouseDownHandler = (event) => {
|
||||
if (!element.contains(event.target)) {
|
||||
element.blur();
|
||||
}
|
||||
};
|
||||
document.addEventListener("mousedown", mouseDownHandler);
|
||||
}
|
||||
|
||||
const widget = {
|
||||
type,
|
||||
name,
|
||||
get value() {
|
||||
return options.getValue?.() ?? undefined;
|
||||
},
|
||||
set value(v) {
|
||||
options.setValue?.(v);
|
||||
widget.callback?.(widget.value);
|
||||
},
|
||||
draw: function (ctx, node, widgetWidth, y, widgetHeight) {
|
||||
if (widget.computedHeight == null) {
|
||||
computeSize.call(node, node.size);
|
||||
}
|
||||
|
||||
const hidden =
|
||||
node.flags?.collapsed ||
|
||||
(!!options.hideOnZoom && app.canvas.ds.scale < 0.5) ||
|
||||
widget.computedHeight <= 0 ||
|
||||
widget.type === "converted-widget";
|
||||
element.hidden = hidden;
|
||||
element.style.display = hidden ? "none" : null;
|
||||
if (hidden) {
|
||||
widget.options.onHide?.(widget);
|
||||
return;
|
||||
}
|
||||
|
||||
const margin = 10;
|
||||
const elRect = ctx.canvas.getBoundingClientRect();
|
||||
const transform = new DOMMatrix()
|
||||
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
|
||||
.multiplySelf(ctx.getTransform())
|
||||
.translateSelf(margin, margin + y);
|
||||
|
||||
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d);
|
||||
|
||||
Object.assign(element.style, {
|
||||
transformOrigin: "0 0",
|
||||
transform: scale,
|
||||
left: `${transform.a + transform.e}px`,
|
||||
top: `${transform.d + transform.f}px`,
|
||||
width: `${widgetWidth - margin * 2}px`,
|
||||
height: `${(widget.computedHeight ?? 50) - margin * 2}px`,
|
||||
position: "absolute",
|
||||
zIndex: app.graph._nodes.indexOf(node),
|
||||
});
|
||||
|
||||
if (enableDomClipping) {
|
||||
element.style.clipPath = getClipPath(node, element, elRect);
|
||||
element.style.willChange = "clip-path";
|
||||
}
|
||||
|
||||
this.options.onDraw?.(widget);
|
||||
},
|
||||
element,
|
||||
options,
|
||||
onRemove() {
|
||||
if (mouseDownHandler) {
|
||||
document.removeEventListener("mousedown", mouseDownHandler);
|
||||
}
|
||||
element.remove();
|
||||
},
|
||||
};
|
||||
|
||||
for (const evt of options.selectOn) {
|
||||
element.addEventListener(evt, () => {
|
||||
app.canvas.selectNode(this);
|
||||
app.canvas.bringToFront(this);
|
||||
});
|
||||
}
|
||||
|
||||
this.addCustomWidget(widget);
|
||||
elementWidgets.add(this);
|
||||
|
||||
const collapse = this.collapse;
|
||||
this.collapse = function() {
|
||||
collapse.apply(this, arguments);
|
||||
if(this.flags?.collapsed) {
|
||||
element.hidden = true;
|
||||
element.style.display = "none";
|
||||
}
|
||||
}
|
||||
|
||||
const onRemoved = this.onRemoved;
|
||||
this.onRemoved = function () {
|
||||
element.remove();
|
||||
elementWidgets.delete(this);
|
||||
onRemoved?.apply(this, arguments);
|
||||
};
|
||||
|
||||
if (!this[SIZE]) {
|
||||
this[SIZE] = true;
|
||||
const onResize = this.onResize;
|
||||
this.onResize = function (size) {
|
||||
options.beforeResize?.call(widget, this);
|
||||
computeSize.call(this, size);
|
||||
onResize?.apply(this, arguments);
|
||||
options.afterResize?.call(widget, this);
|
||||
};
|
||||
}
|
||||
|
||||
return widget;
|
||||
};
|
||||
@ -24,7 +24,7 @@ export function getPngMetadata(file) {
|
||||
const length = dataView.getUint32(offset);
|
||||
// Get the chunk type
|
||||
const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8));
|
||||
if (type === "tEXt") {
|
||||
if (type === "tEXt" || type == "comf") {
|
||||
// Get the keyword
|
||||
let keyword_end = offset + 8;
|
||||
while (pngData[keyword_end] !== 0) {
|
||||
@ -47,6 +47,105 @@ export function getPngMetadata(file) {
|
||||
});
|
||||
}
|
||||
|
||||
function parseExifData(exifData) {
|
||||
// Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian)
|
||||
const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949;
|
||||
|
||||
// Function to read 16-bit and 32-bit integers from binary data
|
||||
function readInt(offset, isLittleEndian, length) {
|
||||
let arr = exifData.slice(offset, offset + length)
|
||||
if (length === 2) {
|
||||
return new DataView(arr.buffer, arr.byteOffset, arr.byteLength).getUint16(0, isLittleEndian);
|
||||
} else if (length === 4) {
|
||||
return new DataView(arr.buffer, arr.byteOffset, arr.byteLength).getUint32(0, isLittleEndian);
|
||||
}
|
||||
}
|
||||
|
||||
// Read the offset to the first IFD (Image File Directory)
|
||||
const ifdOffset = readInt(4, isLittleEndian, 4);
|
||||
|
||||
function parseIFD(offset) {
|
||||
const numEntries = readInt(offset, isLittleEndian, 2);
|
||||
const result = {};
|
||||
|
||||
for (let i = 0; i < numEntries; i++) {
|
||||
const entryOffset = offset + 2 + i * 12;
|
||||
const tag = readInt(entryOffset, isLittleEndian, 2);
|
||||
const type = readInt(entryOffset + 2, isLittleEndian, 2);
|
||||
const numValues = readInt(entryOffset + 4, isLittleEndian, 4);
|
||||
const valueOffset = readInt(entryOffset + 8, isLittleEndian, 4);
|
||||
|
||||
// Read the value(s) based on the data type
|
||||
let value;
|
||||
if (type === 2) {
|
||||
// ASCII string
|
||||
value = String.fromCharCode(...exifData.slice(valueOffset, valueOffset + numValues - 1));
|
||||
}
|
||||
|
||||
result[tag] = value;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Parse the first IFD
|
||||
const ifdData = parseIFD(ifdOffset);
|
||||
return ifdData;
|
||||
}
|
||||
|
||||
function splitValues(input) {
|
||||
var output = {};
|
||||
for (var key in input) {
|
||||
var value = input[key];
|
||||
var splitValues = value.split(':', 2);
|
||||
output[splitValues[0]] = splitValues[1];
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
export function getWebpMetadata(file) {
|
||||
return new Promise((r) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
const webp = new Uint8Array(event.target.result);
|
||||
const dataView = new DataView(webp.buffer);
|
||||
|
||||
// Check that the WEBP signature is present
|
||||
if (dataView.getUint32(0) !== 0x52494646 || dataView.getUint32(8) !== 0x57454250) {
|
||||
console.error("Not a valid WEBP file");
|
||||
r();
|
||||
return;
|
||||
}
|
||||
|
||||
// Start searching for chunks after the WEBP signature
|
||||
let offset = 12;
|
||||
let txt_chunks = {};
|
||||
// Loop through the chunks in the WEBP file
|
||||
while (offset < webp.length) {
|
||||
const chunk_length = dataView.getUint32(offset + 4, true);
|
||||
const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4));
|
||||
if (chunk_type === "EXIF") {
|
||||
if (String.fromCharCode(...webp.slice(offset + 8, offset + 8 + 6)) == "Exif\0\0") {
|
||||
offset += 6;
|
||||
}
|
||||
let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length));
|
||||
for (var key in data) {
|
||||
var value = data[key];
|
||||
let index = value.indexOf(':');
|
||||
txt_chunks[value.slice(0, index)] = value.slice(index + 1);
|
||||
}
|
||||
}
|
||||
|
||||
offset += 8 + chunk_length;
|
||||
}
|
||||
|
||||
r(txt_chunks);
|
||||
};
|
||||
|
||||
reader.readAsArrayBuffer(file);
|
||||
});
|
||||
}
|
||||
|
||||
export function getLatentMetadata(file) {
|
||||
return new Promise((r) => {
|
||||
const reader = new FileReader();
|
||||
|
||||
@ -599,7 +599,7 @@ export class ComfyUI {
|
||||
const fileInput = $el("input", {
|
||||
id: "comfy-file-input",
|
||||
type: "file",
|
||||
accept: ".json,image/png,.latent,.safetensors",
|
||||
accept: ".json,image/png,.latent,.safetensors,image/webp",
|
||||
style: {display: "none"},
|
||||
parent: document.body,
|
||||
onchange: () => {
|
||||
@ -719,20 +719,22 @@ export class ComfyUI {
|
||||
filename += ".json";
|
||||
}
|
||||
}
|
||||
const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string
|
||||
const blob = new Blob([json], {type: "application/json"});
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = $el("a", {
|
||||
href: url,
|
||||
download: filename,
|
||||
style: {display: "none"},
|
||||
parent: document.body,
|
||||
app.graphToPrompt().then(p=>{
|
||||
const json = JSON.stringify(p.workflow, null, 2); // convert the data to a JSON string
|
||||
const blob = new Blob([json], {type: "application/json"});
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = $el("a", {
|
||||
href: url,
|
||||
download: filename,
|
||||
style: {display: "none"},
|
||||
parent: document.body,
|
||||
});
|
||||
a.click();
|
||||
setTimeout(function () {
|
||||
a.remove();
|
||||
window.URL.revokeObjectURL(url);
|
||||
}, 0);
|
||||
});
|
||||
a.click();
|
||||
setTimeout(function () {
|
||||
a.remove();
|
||||
window.URL.revokeObjectURL(url);
|
||||
}, 0);
|
||||
},
|
||||
}),
|
||||
$el("button", {
|
||||
|
||||
97
web/scripts/ui/imagePreview.js
Normal file
97
web/scripts/ui/imagePreview.js
Normal file
@ -0,0 +1,97 @@
|
||||
import { $el } from "../ui.js";
|
||||
|
||||
export function calculateImageGrid(imgs, dw, dh) {
|
||||
let best = 0;
|
||||
let w = imgs[0].naturalWidth;
|
||||
let h = imgs[0].naturalHeight;
|
||||
const numImages = imgs.length;
|
||||
|
||||
let cellWidth, cellHeight, cols, rows, shiftX;
|
||||
// compact style
|
||||
for (let c = 1; c <= numImages; c++) {
|
||||
const r = Math.ceil(numImages / c);
|
||||
const cW = dw / c;
|
||||
const cH = dh / r;
|
||||
const scaleX = cW / w;
|
||||
const scaleY = cH / h;
|
||||
|
||||
const scale = Math.min(scaleX, scaleY, 1);
|
||||
const imageW = w * scale;
|
||||
const imageH = h * scale;
|
||||
const area = imageW * imageH * numImages;
|
||||
|
||||
if (area > best) {
|
||||
best = area;
|
||||
cellWidth = imageW;
|
||||
cellHeight = imageH;
|
||||
cols = c;
|
||||
rows = r;
|
||||
shiftX = c * ((cW - imageW) / 2);
|
||||
}
|
||||
}
|
||||
|
||||
return { cellWidth, cellHeight, cols, rows, shiftX };
|
||||
}
|
||||
|
||||
export function createImageHost(node) {
|
||||
const el = $el("div.comfy-img-preview");
|
||||
let currentImgs;
|
||||
let first = true;
|
||||
|
||||
function updateSize() {
|
||||
let w = null;
|
||||
let h = null;
|
||||
|
||||
if (currentImgs) {
|
||||
let elH = el.clientHeight;
|
||||
if (first) {
|
||||
first = false;
|
||||
// On first run, if we are small then grow a bit
|
||||
if (elH < 190) {
|
||||
elH = 190;
|
||||
}
|
||||
el.style.setProperty("--comfy-widget-min-height", elH);
|
||||
} else {
|
||||
el.style.setProperty("--comfy-widget-min-height", null);
|
||||
}
|
||||
|
||||
const nw = node.size[0];
|
||||
({ cellWidth: w, cellHeight: h } = calculateImageGrid(currentImgs, nw - 20, elH));
|
||||
w += "px";
|
||||
h += "px";
|
||||
|
||||
el.style.setProperty("--comfy-img-preview-width", w);
|
||||
el.style.setProperty("--comfy-img-preview-height", h);
|
||||
}
|
||||
}
|
||||
return {
|
||||
el,
|
||||
updateImages(imgs) {
|
||||
if (imgs !== currentImgs) {
|
||||
if (currentImgs == null) {
|
||||
requestAnimationFrame(() => {
|
||||
updateSize();
|
||||
});
|
||||
}
|
||||
el.replaceChildren(...imgs);
|
||||
currentImgs = imgs;
|
||||
node.onResize(node.size);
|
||||
node.graph.setDirtyCanvas(true, true);
|
||||
}
|
||||
},
|
||||
getHeight() {
|
||||
updateSize();
|
||||
},
|
||||
onDraw() {
|
||||
// Element from point uses a hittest find elements so we need to toggle pointer events
|
||||
el.style.pointerEvents = "all";
|
||||
const over = document.elementFromPoint(app.canvas.mouse[0], app.canvas.mouse[1]);
|
||||
el.style.pointerEvents = "none";
|
||||
|
||||
if(!over) return;
|
||||
// Set the overIndex so Open Image etc work
|
||||
const idx = currentImgs.indexOf(over);
|
||||
node.overIndex = idx;
|
||||
},
|
||||
};
|
||||
}
|
||||
@ -1,5 +1,6 @@
|
||||
import { api } from "./api.js"
|
||||
import { getPngMetadata } from "./pnginfo.js";
|
||||
import "./domWidget.js";
|
||||
|
||||
function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
|
||||
let defaultVal = inputData[1]["default"];
|
||||
@ -24,17 +25,58 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
|
||||
}
|
||||
|
||||
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) {
|
||||
const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, {
|
||||
const widgets = addValueControlWidgets(node, targetWidget, defaultValue, values, {
|
||||
addFilterList: false,
|
||||
});
|
||||
return widgets[0];
|
||||
}
|
||||
|
||||
export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", values, options) {
|
||||
if (!options) options = {};
|
||||
|
||||
const widgets = [];
|
||||
const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, {
|
||||
values: ["fixed", "increment", "decrement", "randomize"],
|
||||
serialize: false, // Don't include this in prompt.
|
||||
});
|
||||
valueControl.afterQueued = () => {
|
||||
widgets.push(valueControl);
|
||||
|
||||
const isCombo = targetWidget.type === "combo";
|
||||
let comboFilter;
|
||||
if (isCombo && options.addFilterList !== false) {
|
||||
comboFilter = node.addWidget("string", "control_filter_list", "", function (v) {}, {
|
||||
serialize: false, // Don't include this in prompt.
|
||||
});
|
||||
widgets.push(comboFilter);
|
||||
}
|
||||
|
||||
valueControl.afterQueued = () => {
|
||||
var v = valueControl.value;
|
||||
|
||||
if (targetWidget.type == "combo" && v !== "fixed") {
|
||||
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
|
||||
let current_length = targetWidget.options.values.length;
|
||||
if (isCombo && v !== "fixed") {
|
||||
let values = targetWidget.options.values;
|
||||
const filter = comboFilter?.value;
|
||||
if (filter) {
|
||||
let check;
|
||||
if (filter.startsWith("/") && filter.endsWith("/")) {
|
||||
try {
|
||||
const regex = new RegExp(filter.substring(1, filter.length - 1));
|
||||
check = (item) => regex.test(item);
|
||||
} catch (error) {
|
||||
console.error("Error constructing RegExp filter for node " + node.id, filter, error);
|
||||
}
|
||||
}
|
||||
if (!check) {
|
||||
const lower = filter.toLocaleLowerCase();
|
||||
check = (item) => item.toLocaleLowerCase().includes(lower);
|
||||
}
|
||||
values = values.filter(item => check(item));
|
||||
if (!values.length && targetWidget.options.values.length) {
|
||||
console.warn("Filter for node " + node.id + " has filtered out all items", filter);
|
||||
}
|
||||
}
|
||||
let current_index = values.indexOf(targetWidget.value);
|
||||
let current_length = values.length;
|
||||
|
||||
switch (v) {
|
||||
case "increment":
|
||||
@ -51,7 +93,7 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
|
||||
current_index = Math.max(0, current_index);
|
||||
current_index = Math.min(current_length - 1, current_index);
|
||||
if (current_index >= 0) {
|
||||
let value = targetWidget.options.values[current_index];
|
||||
let value = values[current_index];
|
||||
targetWidget.value = value;
|
||||
targetWidget.callback(value);
|
||||
}
|
||||
@ -88,7 +130,8 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
|
||||
targetWidget.callback(targetWidget.value);
|
||||
}
|
||||
}
|
||||
return valueControl;
|
||||
|
||||
return widgets;
|
||||
};
|
||||
|
||||
function seedWidget(node, inputName, inputData, app) {
|
||||
@ -98,166 +141,21 @@ function seedWidget(node, inputName, inputData, app) {
|
||||
seed.widget.linkedWidgets = [seedControl];
|
||||
return seed;
|
||||
}
|
||||
|
||||
const MultilineSymbol = Symbol();
|
||||
const MultilineResizeSymbol = Symbol();
|
||||
|
||||
function addMultilineWidget(node, name, opts, app) {
|
||||
const MIN_SIZE = 50;
|
||||
const inputEl = document.createElement("textarea");
|
||||
inputEl.className = "comfy-multiline-input";
|
||||
inputEl.value = opts.defaultVal;
|
||||
inputEl.placeholder = opts.placeholder || "";
|
||||
|
||||
function computeSize(size) {
|
||||
if (node.widgets[0].last_y == null) return;
|
||||
|
||||
let y = node.widgets[0].last_y;
|
||||
let freeSpace = size[1] - y;
|
||||
|
||||
// Compute the height of all non customtext widgets
|
||||
let widgetHeight = 0;
|
||||
const multi = [];
|
||||
for (let i = 0; i < node.widgets.length; i++) {
|
||||
const w = node.widgets[i];
|
||||
if (w.type === "customtext") {
|
||||
multi.push(w);
|
||||
} else {
|
||||
if (w.computeSize) {
|
||||
widgetHeight += w.computeSize()[1] + 4;
|
||||
} else {
|
||||
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// See how large each text input can be
|
||||
freeSpace -= widgetHeight;
|
||||
freeSpace /= multi.length + (!!node.imgs?.length);
|
||||
|
||||
if (freeSpace < MIN_SIZE) {
|
||||
// There isnt enough space for all the widgets, increase the size of the node
|
||||
freeSpace = MIN_SIZE;
|
||||
node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length));
|
||||
node.graph.setDirtyCanvas(true);
|
||||
}
|
||||
|
||||
// Position each of the widgets
|
||||
for (const w of node.widgets) {
|
||||
w.y = y;
|
||||
if (w.type === "customtext") {
|
||||
y += freeSpace;
|
||||
w.computedHeight = freeSpace - multi.length*4;
|
||||
} else if (w.computeSize) {
|
||||
y += w.computeSize()[1] + 4;
|
||||
} else {
|
||||
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
|
||||
}
|
||||
}
|
||||
|
||||
node.inputHeight = freeSpace;
|
||||
}
|
||||
|
||||
const widget = {
|
||||
type: "customtext",
|
||||
name,
|
||||
get value() {
|
||||
return this.inputEl.value;
|
||||
const widget = node.addDOMWidget(name, "customtext", inputEl, {
|
||||
getValue() {
|
||||
return inputEl.value;
|
||||
},
|
||||
set value(x) {
|
||||
this.inputEl.value = x;
|
||||
setValue(v) {
|
||||
inputEl.value = v;
|
||||
},
|
||||
draw: function (ctx, _, widgetWidth, y, widgetHeight) {
|
||||
if (!this.parent.inputHeight) {
|
||||
// If we are initially offscreen when created we wont have received a resize event
|
||||
// Calculate it here instead
|
||||
computeSize(node.size);
|
||||
}
|
||||
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
|
||||
const margin = 10;
|
||||
const elRect = ctx.canvas.getBoundingClientRect();
|
||||
const transform = new DOMMatrix()
|
||||
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
|
||||
.multiplySelf(ctx.getTransform())
|
||||
.translateSelf(margin, margin + y);
|
||||
|
||||
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d)
|
||||
Object.assign(this.inputEl.style, {
|
||||
transformOrigin: "0 0",
|
||||
transform: scale,
|
||||
left: `${transform.a + transform.e}px`,
|
||||
top: `${transform.d + transform.f}px`,
|
||||
width: `${widgetWidth - (margin * 2)}px`,
|
||||
height: `${this.parent.inputHeight - (margin * 2)}px`,
|
||||
position: "absolute",
|
||||
background: (!node.color)?'':node.color,
|
||||
color: (!node.color)?'':'white',
|
||||
zIndex: app.graph._nodes.indexOf(node),
|
||||
});
|
||||
this.inputEl.hidden = !visible;
|
||||
},
|
||||
};
|
||||
widget.inputEl = document.createElement("textarea");
|
||||
widget.inputEl.className = "comfy-multiline-input";
|
||||
widget.inputEl.value = opts.defaultVal;
|
||||
widget.inputEl.placeholder = opts.placeholder || "";
|
||||
document.addEventListener("mousedown", function (event) {
|
||||
if (!widget.inputEl.contains(event.target)) {
|
||||
widget.inputEl.blur();
|
||||
}
|
||||
});
|
||||
widget.parent = node;
|
||||
document.body.appendChild(widget.inputEl);
|
||||
|
||||
node.addCustomWidget(widget);
|
||||
|
||||
app.canvas.onDrawBackground = function () {
|
||||
// Draw node isnt fired once the node is off the screen
|
||||
// if it goes off screen quickly, the input may not be removed
|
||||
// this shifts it off screen so it can be moved back if the node is visible.
|
||||
for (let n in app.graph._nodes) {
|
||||
n = graph._nodes[n];
|
||||
for (let w in n.widgets) {
|
||||
let wid = n.widgets[w];
|
||||
if (Object.hasOwn(wid, "inputEl")) {
|
||||
wid.inputEl.style.left = -8000 + "px";
|
||||
wid.inputEl.style.position = "absolute";
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
node.onRemoved = function () {
|
||||
// When removing this node we need to remove the input from the DOM
|
||||
for (let y in this.widgets) {
|
||||
if (this.widgets[y].inputEl) {
|
||||
this.widgets[y].inputEl.remove();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
widget.onRemove = () => {
|
||||
widget.inputEl?.remove();
|
||||
|
||||
// Restore original size handler if we are the last
|
||||
if (!--node[MultilineSymbol]) {
|
||||
node.onResize = node[MultilineResizeSymbol];
|
||||
delete node[MultilineSymbol];
|
||||
delete node[MultilineResizeSymbol];
|
||||
}
|
||||
};
|
||||
|
||||
if (node[MultilineSymbol]) {
|
||||
node[MultilineSymbol]++;
|
||||
} else {
|
||||
node[MultilineSymbol] = 1;
|
||||
const onResize = (node[MultilineResizeSymbol] = node.onResize);
|
||||
|
||||
node.onResize = function (size) {
|
||||
computeSize(size);
|
||||
|
||||
// Call original resizer handler
|
||||
if (onResize) {
|
||||
onResize.apply(this, arguments);
|
||||
}
|
||||
};
|
||||
}
|
||||
widget.inputEl = inputEl;
|
||||
|
||||
return { minWidth: 400, minHeight: 200, widget };
|
||||
}
|
||||
@ -306,14 +204,23 @@ export const ComfyWidgets = {
|
||||
};
|
||||
},
|
||||
BOOLEAN(node, inputName, inputData) {
|
||||
let defaultVal = inputData[1]["default"];
|
||||
let defaultVal = false;
|
||||
let options = {};
|
||||
if (inputData[1]) {
|
||||
if (inputData[1].default)
|
||||
defaultVal = inputData[1].default;
|
||||
if (inputData[1].label_on)
|
||||
options["on"] = inputData[1].label_on;
|
||||
if (inputData[1].label_off)
|
||||
options["off"] = inputData[1].label_off;
|
||||
}
|
||||
return {
|
||||
widget: node.addWidget(
|
||||
"toggle",
|
||||
inputName,
|
||||
defaultVal,
|
||||
() => {},
|
||||
{"on": inputData[1].label_on, "off": inputData[1].label_off}
|
||||
options,
|
||||
)
|
||||
};
|
||||
},
|
||||
|
||||
@ -409,6 +409,21 @@ dialog::backdrop {
|
||||
width: calc(100% - 10px);
|
||||
}
|
||||
|
||||
.comfy-img-preview {
|
||||
pointer-events: none;
|
||||
overflow: hidden;
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-content: flex-start;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.comfy-img-preview img {
|
||||
object-fit: contain;
|
||||
width: var(--comfy-img-preview-width);
|
||||
height: var(--comfy-img-preview-height);
|
||||
}
|
||||
|
||||
/* Search box */
|
||||
|
||||
.litegraph.litesearchbox {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user