mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 05:23:37 +08:00
Merge remote-tracking branch 'origin/master' into group-nodes
This commit is contained in:
commit
4c928d2371
@ -27,7 +27,6 @@ class ControlNet(nn.Module):
|
|||||||
model_channels,
|
model_channels,
|
||||||
hint_channels,
|
hint_channels,
|
||||||
num_res_blocks,
|
num_res_blocks,
|
||||||
attention_resolutions,
|
|
||||||
dropout=0,
|
dropout=0,
|
||||||
channel_mult=(1, 2, 4, 8),
|
channel_mult=(1, 2, 4, 8),
|
||||||
conv_resample=True,
|
conv_resample=True,
|
||||||
@ -52,6 +51,7 @@ class ControlNet(nn.Module):
|
|||||||
use_linear_in_transformer=False,
|
use_linear_in_transformer=False,
|
||||||
adm_in_channels=None,
|
adm_in_channels=None,
|
||||||
transformer_depth_middle=None,
|
transformer_depth_middle=None,
|
||||||
|
transformer_depth_output=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=comfy.ops,
|
operations=comfy.ops,
|
||||||
):
|
):
|
||||||
@ -79,10 +79,7 @@ class ControlNet(nn.Module):
|
|||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.model_channels = model_channels
|
self.model_channels = model_channels
|
||||||
if isinstance(transformer_depth, int):
|
|
||||||
transformer_depth = len(channel_mult) * [transformer_depth]
|
|
||||||
if transformer_depth_middle is None:
|
|
||||||
transformer_depth_middle = transformer_depth[-1]
|
|
||||||
if isinstance(num_res_blocks, int):
|
if isinstance(num_res_blocks, int):
|
||||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||||
else:
|
else:
|
||||||
@ -90,18 +87,16 @@ class ControlNet(nn.Module):
|
|||||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||||
self.num_res_blocks = num_res_blocks
|
self.num_res_blocks = num_res_blocks
|
||||||
|
|
||||||
if disable_self_attentions is not None:
|
if disable_self_attentions is not None:
|
||||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||||
assert len(disable_self_attentions) == len(channel_mult)
|
assert len(disable_self_attentions) == len(channel_mult)
|
||||||
if num_attention_blocks is not None:
|
if num_attention_blocks is not None:
|
||||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
|
||||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
|
||||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
|
||||||
f"attention will still not be set.")
|
|
||||||
|
|
||||||
self.attention_resolutions = attention_resolutions
|
transformer_depth = transformer_depth[:]
|
||||||
|
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.channel_mult = channel_mult
|
self.channel_mult = channel_mult
|
||||||
self.conv_resample = conv_resample
|
self.conv_resample = conv_resample
|
||||||
@ -180,11 +175,14 @@ class ControlNet(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
operations=operations
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
ch = mult * model_channels
|
ch = mult * model_channels
|
||||||
if ds in attention_resolutions:
|
num_transformers = transformer_depth.pop(0)
|
||||||
|
if num_transformers > 0:
|
||||||
if num_head_channels == -1:
|
if num_head_channels == -1:
|
||||||
dim_head = ch // num_heads
|
dim_head = ch // num_heads
|
||||||
else:
|
else:
|
||||||
@ -201,9 +199,9 @@ class ControlNet(nn.Module):
|
|||||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||||
layers.append(
|
layers.append(
|
||||||
SpatialTransformer(
|
SpatialTransformer(
|
||||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, operations=operations
|
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
@ -223,11 +221,13 @@ class ControlNet(nn.Module):
|
|||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
down=True,
|
down=True,
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
operations=operations
|
operations=operations
|
||||||
)
|
)
|
||||||
if resblock_updown
|
if resblock_updown
|
||||||
else Downsample(
|
else Downsample(
|
||||||
ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
|
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -245,7 +245,7 @@ class ControlNet(nn.Module):
|
|||||||
if legacy:
|
if legacy:
|
||||||
#num_heads = 1
|
#num_heads = 1
|
||||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
self.middle_block = TimestepEmbedSequential(
|
mid_block = [
|
||||||
ResBlock(
|
ResBlock(
|
||||||
ch,
|
ch,
|
||||||
time_embed_dim,
|
time_embed_dim,
|
||||||
@ -253,12 +253,15 @@ class ControlNet(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
operations=operations
|
operations=operations
|
||||||
),
|
)]
|
||||||
SpatialTransformer( # always uses a self-attn
|
if transformer_depth_middle >= 0:
|
||||||
|
mid_block += [SpatialTransformer( # always uses a self-attn
|
||||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, operations=operations
|
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||||
),
|
),
|
||||||
ResBlock(
|
ResBlock(
|
||||||
ch,
|
ch,
|
||||||
@ -267,9 +270,11 @@ class ControlNet(nn.Module):
|
|||||||
dims=dims,
|
dims=dims,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
use_scale_shift_norm=use_scale_shift_norm,
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=device,
|
||||||
operations=operations
|
operations=operations
|
||||||
),
|
)]
|
||||||
)
|
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
|
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
|
|||||||
@ -36,6 +36,8 @@ parser = argparse.ArgumentParser()
|
|||||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||||
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||||
|
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
||||||
|
|
||||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
||||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
||||||
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
|
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
|
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
|
||||||
from .utils import load_torch_file, transformers_convert
|
from .utils import load_torch_file, transformers_convert, common_upscale
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import contextlib
|
import contextlib
|
||||||
@ -7,6 +7,18 @@ import contextlib
|
|||||||
import comfy.ops
|
import comfy.ops
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
def clip_preprocess(image, size=224):
    """Resize, center-crop and normalize an image batch for a CLIP vision model.

    Steps, in order: keep only the first three channels, bicubic-resize
    (antialiased) so the shorter spatial side is about ``size`` pixels,
    center-crop to ``size`` x ``size``, snap values to 8-bit levels
    (multiples of 1/255), then normalize each channel with fixed mean/std
    constants.

    Args:
        image: channels-last tensor indexed here as
            ``(batch, height, width, channels)``. Values are clipped to
            [0, 1] after the resize, so inputs are presumably in that range
            (TODO confirm against callers). Channels beyond the third
            (e.g. alpha) are ignored because the normalization constants
            only cover three channels.
        size: edge length in pixels of the square output crop.

    Returns:
        Channels-first tensor ``(batch, 3, size, size)`` with the device and
        dtype of ``image``.
    """
    # mean/std below have exactly three entries; a 4-channel (RGBA) input
    # would otherwise fail to broadcast in the final normalization.
    if image.shape[-1] > 3:
        image = image[..., :3]

    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)

    # Scale factor that maps the shorter spatial side onto `size`.
    scale = size / min(image.shape[1], image.shape[2])
    target_hw = (round(scale * image.shape[1]), round(scale * image.shape[2]))
    resized = torch.nn.functional.interpolate(image.movedim(-1, 1), size=target_hw, mode="bicubic", antialias=True)

    # Center crop: split the surplus evenly between both sides.
    top = (resized.shape[2] - size) // 2
    left = (resized.shape[3] - size) // 2
    cropped = resized[:, :, top:top + size, left:left + size]

    # Quantize to 8-bit levels; bicubic overshoot is clipped back into range.
    quantized = torch.clip(255. * cropped, 0, 255).round() / 255.0
    return (quantized - mean.view([3, 1, 1])) / std.view([3, 1, 1])
|
||||||
|
|
||||||
class ClipVisionModel():
|
class ClipVisionModel():
|
||||||
def __init__(self, json_config):
|
def __init__(self, json_config):
|
||||||
@ -23,25 +35,12 @@ class ClipVisionModel():
|
|||||||
self.model.to(self.dtype)
|
self.model.to(self.dtype)
|
||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||||
self.processor = CLIPImageProcessor(crop_size=224,
|
|
||||||
do_center_crop=True,
|
|
||||||
do_convert_rgb=True,
|
|
||||||
do_normalize=True,
|
|
||||||
do_resize=True,
|
|
||||||
image_mean=[ 0.48145466,0.4578275,0.40821073],
|
|
||||||
image_std=[0.26862954,0.26130258,0.27577711],
|
|
||||||
resample=3, #bicubic
|
|
||||||
size=224)
|
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return self.model.load_state_dict(sd, strict=False)
|
return self.model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
def encode_image(self, image):
|
def encode_image(self, image):
|
||||||
img = torch.clip((255. * image), 0, 255).round().int()
|
|
||||||
img = list(map(lambda a: a, img))
|
|
||||||
inputs = self.processor(images=img, return_tensors="pt")
|
|
||||||
comfy.model_management.load_model_gpu(self.patcher)
|
comfy.model_management.load_model_gpu(self.patcher)
|
||||||
pixel_values = inputs['pixel_values'].to(self.load_device)
|
pixel_values = clip_preprocess(image.to(self.load_device))
|
||||||
|
|
||||||
if self.dtype != torch.float32:
|
if self.dtype != torch.float32:
|
||||||
precision_scope = torch.autocast
|
precision_scope = torch.autocast
|
||||||
|
|||||||
64
comfy/conds.py
Normal file
64
comfy/conds.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import enum
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
|
||||||
|
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
    """Return the least common multiple of the integers ``a`` and ``b``.

    The result is never negative. As with ``math.lcm``, the result is 0 when
    either argument is 0; handling that case up front also avoids dividing by
    ``math.gcd(0, 0)``, which is 0.
    """
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // math.gcd(a, b)
|
||||||
|
|
||||||
|
class CONDRegular:
    """Holder for one conditioning tensor with batching and concat helpers."""

    def __init__(self, cond):
        # The wrapped tensor; every method below reads it through self.cond.
        self.cond = cond

    def _copy_with(self, cond):
        """Create a new object of this instance's class wrapping ``cond``."""
        return self.__class__(cond)

    def process_cond(self, batch_size, device, **kwargs):
        """Return a copy whose tensor is batched to ``batch_size`` on ``device``.

        Batching is delegated to ``comfy.utils.repeat_to_batch_size``; any
        extra keyword arguments are accepted and ignored.
        """
        batched = comfy.utils.repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(batched.to(device))

    def can_concat(self, other):
        """Tell whether ``other`` holds a tensor of exactly the same shape."""
        return self.cond.shape == other.cond.shape

    def concat(self, others):
        """Concatenate this tensor with those of ``others`` via ``torch.cat``.

        Returns the raw concatenated tensor (default dimension 0), not a
        wrapper object.
        """
        return torch.cat([self.cond] + [item.cond for item in others])
|
||||||
|
|
||||||
|
class CONDNoiseShape(CONDRegular):
    """Conditioning tensor that is cropped to an area before batching."""

    def process_cond(self, batch_size, device, area, **kwargs):
        """Crop dims 2 and 3 to ``area``, then batch and move to ``device``.

        ``area`` is indexed as ``(size_dim2, size_dim3, offset_dim2,
        offset_dim3)``; presumably that is (height, width, y, x) — confirm
        against callers. Extra keyword arguments are ignored.
        """
        dim2_window = slice(area[2], area[0] + area[2])
        dim3_window = slice(area[3], area[1] + area[3])
        cropped = self.cond[:, :, dim2_window, dim3_window]
        batched = comfy.utils.repeat_to_batch_size(cropped, batch_size)
        return self._copy_with(batched.to(device))
|
||||||
|
|
||||||
|
|
||||||
|
class CONDCrossAttn(CONDRegular):
    """Conditioning tensor whose dim 1 may be padded by repetition to concat.

    The tensors are indexed on dims 0, 1 and 2; presumably the layout is
    (batch, tokens, features) — confirm against callers.
    """

    def can_concat(self, other):
        """Tell whether ``other`` can be concatenated with this condition.

        Identical shapes always can. Otherwise dims 0 and 2 must match and
        the shorter dim 1 must need at most a 4x repeat to reach the least
        common multiple of both lengths.
        """
        shape_a = self.cond.shape
        shape_b = other.cond.shape
        if shape_a == shape_b:
            return True

        if shape_a[0] != shape_b[0] or shape_a[2] != shape_b[2]: #these 2 cases should not happen
            return False

        common_len = lcm(shape_a[1], shape_b[1])
        repeats = common_len // min(shape_a[1], shape_b[1])
        #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
        return repeats <= 4

    def concat(self, others):
        """Concatenate with ``others`` after equalizing dim 1 by repetition.

        Every tensor is repeated along dim 1 up to the least common multiple
        of all dim-1 lengths, then all are joined with ``torch.cat``.
        """
        tensors = [self.cond] + [item.cond for item in others]

        target_len = self.cond.shape[1]
        for tensor in tensors[1:]:
            target_len = lcm(target_len, tensor.shape[1])

        padded = [
            #padding with repeat doesn't change result
            tensor.repeat(1, target_len // tensor.shape[1], 1) if tensor.shape[1] < target_len else tensor
            for tensor in tensors
        ]
        return torch.cat(padded)
|
||||||
@ -132,6 +132,7 @@ class ControlNet(ControlBase):
|
|||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
|
self.model_sampling_current = None
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
@ -156,10 +157,13 @@ class ControlNet(ControlBase):
|
|||||||
|
|
||||||
|
|
||||||
context = cond['c_crossattn']
|
context = cond['c_crossattn']
|
||||||
y = cond.get('c_adm', None)
|
y = cond.get('y', None)
|
||||||
if y is not None:
|
if y is not None:
|
||||||
y = y.to(self.control_model.dtype)
|
y = y.to(self.control_model.dtype)
|
||||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
timestep = self.model_sampling_current.timestep(t)
|
||||||
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||||
|
|
||||||
|
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
|
||||||
return self.control_merge(None, control, control_prev, output_dtype)
|
return self.control_merge(None, control, control_prev, output_dtype)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
@ -172,6 +176,14 @@ class ControlNet(ControlBase):
|
|||||||
out.append(self.control_model_wrapped)
|
out.append(self.control_model_wrapped)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def pre_run(self, model, percent_to_timestep_function):
    """Run the parent class's pre-run step, then cache ``model.model_sampling``.

    The cached object is stored on ``self.model_sampling_current``.
    """
    super().pre_run(model, percent_to_timestep_function)
    sampling_object = model.model_sampling
    self.model_sampling_current = sampling_object
|
||||||
|
|
||||||
|
def cleanup(self):
    """Forget the cached model-sampling object, then run the parent cleanup."""
    # Reset first so the reference is dropped before the parent's cleanup runs.
    self.model_sampling_current = None
    super().cleanup()
|
||||||
|
|
||||||
class ControlLoraOps:
|
class ControlLoraOps:
|
||||||
class Linear(torch.nn.Module):
|
class Linear(torch.nn.Module):
|
||||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||||
|
|||||||
@ -852,6 +852,12 @@ class SigmaConvert:
|
|||||||
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
||||||
return log_mean_coeff - log_std
|
return log_mean_coeff - log_std
|
||||||
|
|
||||||
|
def predict_eps_sigma(model, input, sigma_in, **kwargs):
    """Return ``(scaled_input - model(scaled_input, sigma_in)) / sigma``.

    ``input`` is first multiplied by ``sqrt(sigma**2 + 1)``. ``sigma`` is
    ``sigma_in`` reshaped to its first dimension followed by singleton
    dimensions, so it broadcasts against ``input``. Going by the name, the
    result is the predicted noise (eps) for a model that returns the denoised
    sample — confirm against the model being wrapped.

    Args:
        model: callable invoked as ``model(scaled_input, sigma_in, **kwargs)``.
        input: tensor to be scaled and passed to ``model``.
        sigma_in: noise-level tensor, handed to ``model`` unreshaped.
        **kwargs: forwarded verbatim to ``model``.
    """
    broadcast_shape = sigma_in.shape[:1] + (1,) * (input.ndim - 1)
    sigma = sigma_in.view(broadcast_shape)
    scaled = input * ((sigma ** 2 + 1.0) ** 0.5)
    model_output = model(scaled, sigma_in, **kwargs)
    return (scaled - model_output) / sigma
|
||||||
|
|
||||||
|
|
||||||
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
||||||
timesteps = sigmas.clone()
|
timesteps = sigmas.clone()
|
||||||
if sigmas[-1] == 0:
|
if sigmas[-1] == 0:
|
||||||
@ -874,14 +880,14 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
|||||||
model_type = "noise"
|
model_type = "noise"
|
||||||
|
|
||||||
model_fn = model_wrapper(
|
model_fn = model_wrapper(
|
||||||
model.predict_eps_sigma,
|
lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
|
||||||
ns,
|
ns,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
guidance_type="uncond",
|
guidance_type="uncond",
|
||||||
model_kwargs=extra_args,
|
model_kwargs=extra_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
order = min(3, len(timesteps) - 1)
|
order = min(3, len(timesteps) - 2)
|
||||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
||||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
||||||
x /= ns.marginal_alpha(timesteps[-1])
|
x /= ns.marginal_alpha(timesteps[-1])
|
||||||
|
|||||||
@ -1,194 +0,0 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from . import sampling, utils
|
|
||||||
|
|
||||||
|
|
||||||
class VDenoiser(nn.Module):
    """A v-diffusion-pytorch model wrapper for k-diffusion."""

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model
        # Used by get_scalings as the fixed data scale.
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return the ``(c_skip, c_out, c_in)`` scaling factors for ``sigma``."""
        data_var = self.sigma_data ** 2
        total_var = sigma ** 2 + data_var
        c_skip = data_var / total_var
        c_out = -sigma * self.sigma_data / total_var ** 0.5
        c_in = 1 / total_var ** 0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigma):
        """Map a noise level to a timestep: ``atan(sigma) / pi * 2``."""
        return sigma.atan() / math.pi * 2

    def t_to_sigma(self, t):
        """Inverse of ``sigma_to_t``: ``tan(t * pi / 2)``."""
        return (t * math.pi / 2).tan()

    def loss(self, input, noise, sigma, **kwargs):
        """Mean squared error between the model output and its target.

        The mean is taken over everything except the first dimension.
        """
        c_skip, c_out, c_in = (utils.append_dims(c, input.ndim) for c in self.get_scalings(sigma))
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        prediction = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        squared_error = (prediction - target).pow(2)
        return squared_error.flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Combine the inner model's output with ``input`` via the scalings."""
        c_skip, c_out, c_in = (utils.append_dims(c, input.ndim) for c in self.get_scalings(sigma))
        model_output = self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return model_output * c_out + input * c_skip
|
|
||||||
|
|
||||||
|
|
||||||
class DiscreteSchedule(nn.Module):
    """A mapping between continuous noise levels (sigmas) and a list of discrete noise
    levels."""

    def __init__(self, sigmas, quantize):
        super().__init__()
        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())
        # Default for sigma_to_t when its own `quantize` argument is None.
        self.quantize = quantize

    @property
    def sigma_min(self):
        """First entry of the sigma table."""
        return self.sigmas[0]

    @property
    def sigma_max(self):
        """Last entry of the sigma table."""
        return self.sigmas[-1]

    def get_sigmas(self, n=None):
        """Return a sigma sequence passed through ``sampling.append_zero``.

        With ``n=None`` the whole table is used in reversed order; otherwise
        ``n`` timesteps are spaced linearly from the last table index down to
        0 and converted with ``t_to_sigma``.
        """
        if n is not None:
            last_index = len(self.sigmas) - 1
            steps = torch.linspace(last_index, 0, n, device=self.sigmas.device)
            return sampling.append_zero(self.t_to_sigma(steps))
        return sampling.append_zero(self.sigmas.flip(0))

    def sigma_to_discrete_timestep(self, sigma):
        """Return the index of the table entry nearest to ``sigma`` in log space."""
        log_sigma = sigma.log()
        distances = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        nearest = distances.abs().argmin(dim=0)
        return nearest.view(sigma.shape)

    def sigma_to_t(self, sigma, quantize=None):
        """Convert ``sigma`` to a (possibly fractional) table index.

        When quantizing (argument, or the instance default if it is None) the
        nearest index is returned; otherwise the index is interpolated
        linearly in log-sigma between the two bracketing table entries.
        """
        use_nearest = self.quantize if quantize is None else quantize
        if use_nearest:
            return self.sigma_to_discrete_timestep(sigma)

        log_sigma = sigma.log()
        distances = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        # Pick the lower bracket; the clamp keeps low_idx + 1 a valid index.
        low_idx = distances.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low = self.log_sigmas[low_idx]
        high = self.log_sigmas[high_idx]
        weight = ((low - log_sigma) / (low - high)).clamp(0, 1)
        t = (1 - weight) * low_idx + weight * high_idx
        return t.view(sigma.shape)

    def t_to_sigma(self, t):
        """Convert a (possibly fractional) table index to a sigma.

        Interpolates linearly in log-sigma between the floor and ceil entries.
        """
        t = t.float()
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        if t.device.type == 'mps':
            # The fractional part is computed by subtraction on 'mps' devices.
            weight = t-low_idx
        else:
            weight = t.frac()
        log_sigma = (1 - weight) * self.log_sigmas[low_idx] + weight * self.log_sigmas[high_idx]
        return log_sigma.exp()

    def predict_eps_discrete_timestep(self, input, t, **kwargs):
        """Like ``predict_eps_sigma`` but takes a timestep instead of a sigma.

        Timesteps that are not int32/int64 are rounded before conversion.
        """
        if t.dtype not in (torch.int64, torch.int32):
            t = t.round()
        sigma = self.t_to_sigma(t)
        sigma_b = utils.append_dims(sigma, input.ndim)
        scaled = input * ((sigma_b ** 2 + 1.0) ** 0.5)
        return (scaled - self(scaled, sigma, **kwargs)) / sigma_b

    def predict_eps_sigma(self, input, sigma, **kwargs):
        """Return ``(scaled_input - self(scaled_input, sigma)) / sigma``.

        ``input`` is multiplied by ``sqrt(sigma**2 + 1)`` before the call.
        """
        sigma_b = utils.append_dims(sigma, input.ndim)
        scaled = input * ((sigma_b ** 2 + 1.0) ** 0.5)
        return (scaled - self(scaled, sigma, **kwargs)) / sigma_b
|
|
||||||
|
|
||||||
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output eps (the predicted
    noise)."""

    def __init__(self, model, alphas_cumprod, quantize):
        # Sigma table derived from the cumulative alphas.
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        super().__init__(sigmas, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return the ``(c_out, c_in)`` scaling factors for ``sigma``."""
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return -sigma, c_in

    def get_eps(self, *args, **kwargs):
        """Call the wrapped model; subclasses override this to adapt its API."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Mean squared error between predicted eps and ``noise``.

        The mean is taken over everything except the first dimension.
        """
        # c_out is computed alongside c_in but is not needed for the loss.
        _c_out, c_in = (utils.append_dims(c, input.ndim) for c in self.get_scalings(sigma))
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        squared_error = (eps - noise).pow(2)
        return squared_error.flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Return ``input + eps * c_out`` using the model's eps prediction."""
        c_out, c_in = (utils.append_dims(c, input.ndim) for c in self.get_scalings(sigma))
        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return input + eps * c_out
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for OpenAI diffusion models."""

    def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
        # Convert the diffusion object's cumulative alphas to a float32 tensor.
        cumulative_alphas = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
        super().__init__(model, cumulative_alphas, quantize=quantize)
        self.has_learned_sigmas = has_learned_sigmas

    def get_eps(self, *args, **kwargs):
        """Call the wrapped model and return its eps prediction.

        With ``has_learned_sigmas`` set, the output is split in two along
        dim 1 and only the first chunk is returned.
        """
        raw_output = self.inner_model(*args, **kwargs)
        return raw_output.chunk(2, dim=1)[0] if self.has_learned_sigmas else raw_output
|
|
||||||
|
|
||||||
|
|
||||||
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for CompVis diffusion models."""

    def __init__(self, model, quantize=False, device='cpu'):
        # `device` is accepted but never read in this constructor.
        cumulative_alphas = model.alphas_cumprod
        super().__init__(model, cumulative_alphas, quantize=quantize)

    def get_eps(self, *args, **kwargs):
        """Forward every argument to the wrapped model's ``apply_model``."""
        apply_model = self.inner_model.apply_model
        return apply_model(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class DiscreteVDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output v."""

    def __init__(self, model, alphas_cumprod, quantize):
        # Sigma table derived from the cumulative alphas.
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        super().__init__(sigmas, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return the ``(c_skip, c_out, c_in)`` scaling factors for ``sigma``."""
        data_var = self.sigma_data ** 2
        total_var = sigma ** 2 + data_var
        c_skip = data_var / total_var
        c_out = -sigma * self.sigma_data / total_var ** 0.5
        c_in = 1 / total_var ** 0.5
        return c_skip, c_out, c_in

    def get_v(self, *args, **kwargs):
        """Call the wrapped model; subclasses override this to adapt its API."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Mean squared error between the model output and its target.

        The mean is taken over everything except the first dimension.
        """
        c_skip, c_out, c_in = (utils.append_dims(c, input.ndim) for c in self.get_scalings(sigma))
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        prediction = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        squared_error = (prediction - target).pow(2)
        return squared_error.flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Combine the model's v output with ``input`` via the scalings."""
        c_skip, c_out, c_in = (utils.append_dims(c, input.ndim) for c in self.get_scalings(sigma))
        v = self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return v * c_out + input * c_skip
|
|
||||||
|
|
||||||
|
|
||||||
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
    """A wrapper for CompVis diffusion models that output v."""

    def __init__(self, model, quantize=False, device='cpu'):
        # `device` is accepted but never read in this constructor.
        cumulative_alphas = model.alphas_cumprod
        super().__init__(model, cumulative_alphas, quantize=quantize)

    def get_v(self, x, t, cond, **kwargs):
        """Call the wrapped model's ``apply_model`` with ``x``, ``t``, ``cond``.

        Additional keyword arguments are accepted but not forwarded.
        """
        apply_model = self.inner_model.apply_model
        return apply_model(x, t, cond)
|
|
||||||
@ -1,418 +0,0 @@
|
|||||||
"""SAMPLING ONLY."""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
|
||||||
|
|
||||||
|
|
||||||
class DDIMSampler(object):
    """DDIM sampler wrapped around a model that exposes ``apply_model``.

    The wrapped ``model`` is read for ``num_timesteps``, ``alphas_cumprod`` and
    ``alphas_cumprod_prev``. Derived schedule values are stored as plain
    attributes on the sampler through :meth:`register_buffer`.
    """

    def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
        """Store the model and the sampling configuration.

        Args:
            model: object providing ``num_timesteps``; other attributes are
                read lazily by the remaining methods.
            schedule: kept as ``self.schedule``; not read elsewhere in this class.
            device: device that ``register_buffer`` moves tensors to.
            **kwargs: only ``parameterization`` is read (default ``"eps"``).
        """
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule
        self.device = device
        self.parameterization = kwargs.get("parameterization", "eps")

    def register_buffer(self, name, attr):
        """Set ``self.<name> = attr``; tensors on another device are cast to float and moved."""
        if isinstance(attr, torch.Tensor) and attr.device != self.device:
            attr = attr.float().to(self.device)
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Derive ``ddim_num_steps`` timesteps and build the schedule from them."""
        ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize,
                                             num_ddim_timesteps=ddim_num_steps,
                                             num_ddpm_timesteps=self.ddpm_num_timesteps,
                                             verbose=verbose)
        self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose)

    def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True):
        """Register all schedule buffers for an explicit list of DDIM timesteps.

        ``ddim_timesteps`` is stored as a torch tensor in ``self.ddim_timesteps``.
        """
        self.ddim_timesteps = torch.tensor(ddim_timesteps)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'

        def to_torch(x):
            return x.clone().detach().to(torch.float32).to(self.device)

        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        alphas_cpu = alphas_cumprod.cpu()
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cpu)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cpu)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cpu)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cpu)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cpu - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cpu,
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample_custom(self,
                      ddim_timesteps,
                      conditioning=None,
                      callback=None,
                      img_callback=None,
                      quantize_x0=False,
                      eta=0.,
                      mask=None,
                      x0=None,
                      temperature=1.,
                      noise_dropout=0.,
                      score_corrector=None,
                      corrector_kwargs=None,
                      verbose=True,
                      x_T=None,
                      log_every_t=100,
                      unconditional_guidance_scale=1.,
                      unconditional_conditioning=None,  # same format as the conditioning, e.g. encoded tokens
                      dynamic_threshold=None,
                      ucg_schedule=None,
                      denoise_function=None,
                      extra_args=None,
                      to_zero=True,
                      end_step=None,
                      disable_pbar=False,
                      **kwargs
                      ):
        """Sample along caller-supplied ``ddim_timesteps`` starting from ``x_T``.

        The output shape is taken from ``x_T``, so ``x_T`` must be given.

        Returns:
            ``(samples, intermediates)`` as produced by :meth:`ddim_sampling`.

        Raises:
            ValueError: if ``x_T`` is ``None``.
        """
        if x_T is None:
            # The shape below is read from x_T; fail with a clear message
            # instead of an AttributeError on None.
            raise ValueError("sample_custom requires x_T (the initial latent)")
        self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
        samples, intermediates = self.ddim_sampling(conditioning, x_T.shape,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    ucg_schedule=ucg_schedule,
                                                    denoise_function=denoise_function,
                                                    extra_args=extra_args,
                                                    to_zero=to_zero,
                                                    end_step=end_step,
                                                    disable_pbar=disable_pbar
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,  # same format as the conditioning, e.g. encoded tokens
               dynamic_threshold=None,
               ucg_schedule=None,
               **kwargs
               ):
        """Build an ``S``-step schedule and sample a batch of shape ``(batch_size, *shape)``.

        A warning is printed when the leading dimension of ``conditioning``
        differs from ``batch_size``; sampling continues regardless.

        Returns:
            ``(samples, intermediates)`` as produced by :meth:`ddim_sampling`.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        # Report this entry's own batch size; ``cbs`` is only
                        # bound in the dict branch above.
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    ucg_schedule=ucg_schedule,
                                                    denoise_function=None,
                                                    extra_args=None
                                                    )
        return samples, intermediates

    def q_sample(self, x_start, t, noise=None):
        """Return ``sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise``.

        ``noise`` defaults to ``torch.randn_like(x_start)``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                      ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None,
                      disable_pbar=False):
        """Run the reverse sampling loop.

        Args:
            cond: conditioning forwarded to :meth:`p_sample_ddim`.
            shape: shape of the sampled tensor; ``shape[0]`` is the batch size.
            x_T: starting tensor; ``torch.randn(shape)`` is used when ``None``.
            timesteps: when given without ``ddim_use_original_steps``, selects a
                prefix of ``self.ddim_timesteps``.
            mask, x0: when ``mask`` is given, the image is blended each step
                with a noised copy of ``x0`` (``x0`` must then be provided).
            ucg_schedule: optional per-step guidance scales; its length must
                equal the number of timesteps.
            to_zero: when true the last ``pred_x0`` is returned; otherwise the
                last image is divided by ``sqrt(alpha)`` at ``index - 1``.
            end_step: the loop only visits the first ``end_step`` timesteps.

        Returns:
            ``(img, intermediates)`` where ``intermediates`` holds the
            ``'x_inter'`` and ``'pred_x0'`` lists.
        """
        device = self.model.alphas_cumprod.device
        b = shape[0]
        img = torch.randn(shape, device=device) if x_T is None else x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        if ddim_use_original_steps:
            # Materialise as a list: a bare ``reversed(range(...))`` iterator
            # can be neither sliced nor measured with ``len`` below.
            time_range = list(reversed(range(0, timesteps)))
            total_steps = timesteps
        else:
            time_range = timesteps.flip(0)
            total_steps = timesteps.shape[0]

        iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            if ucg_schedule is not None:
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold, denoise_function=denoise_function,
                                      extra_args=extra_args)
            img, pred_x0 = outs
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        if to_zero:
            img = pred_x0
        else:
            if ddim_use_original_steps:
                sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            else:
                sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            img /= sqrt_alphas_cumprod[index - 1]

        return img, intermediates

    @staticmethod
    def _merge_conditioning(unconditional_conditioning, c):
        """Concatenate unconditional and conditional inputs, mirroring the container type of ``c``."""
        if isinstance(c, dict):
            assert isinstance(unconditional_conditioning, dict)
            c_in = dict()
            for k in c:
                if isinstance(c[k], list):
                    c_in[k] = [torch.cat([unconditional_conditioning[k][i], item])
                               for i, item in enumerate(c[k])]
                else:
                    c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
            return c_in
        if isinstance(c, list):
            assert isinstance(unconditional_conditioning, list)
            return [torch.cat([unconditional_conditioning[i], item]) for i, item in enumerate(c)]
        return torch.cat([unconditional_conditioning, c])

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None, denoise_function=None, extra_args=None):
        """Perform one reverse step from ``x`` at timestep ``t``.

        The model output comes from ``denoise_function`` when given, otherwise
        from ``self.model.apply_model`` (with classifier-free guidance when an
        unconditional conditioning and a scale other than 1 are supplied).

        Returns:
            ``(x_prev, pred_x0)``.

        Raises:
            NotImplementedError: if ``dynamic_threshold`` is not ``None``.
        """
        b, device = x.shape[0], x.device

        if denoise_function is not None:
            # ``extra_args`` defaults to None, which cannot be **-unpacked.
            model_output = denoise_function(x, t, **(extra_args or {}))
        elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = self._merge_conditioning(unconditional_conditioning, c)
            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

        if self.parameterization == "v":
            e_t = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output
                   + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.parameterization == "eps", 'not implemented'
            # ``corrector_kwargs`` defaults to None, which cannot be **-unpacked.
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **(corrector_kwargs or {}))

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        # NOTE(review): the original-steps sigmas are read from the model here,
        # while make_schedule_timesteps registers them on the sampler — confirm.
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        if self.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x
                       - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
        """Step ``x0`` forward through the first ``t_enc`` schedule entries using model predictions.

        Returns:
            ``(x_next, out)`` where ``out`` contains ``'x_encoded'``,
            ``'intermediate_steps'`` and, when ``return_intermediates`` is
            truthy, ``'intermediates'``.
        """
        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc='Encoding Image'):
            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
            if unconditional_guidance_scale == 1.:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                           torch.cat((unconditional_conditioning, c))), 2)
                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = alphas_next[i].sqrt() * (
                (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
            x_next = xt_weighted + weighted_noise_pred
            if return_intermediates and i % (
                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)
            if callback:
                callback(i)

        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
        if return_intermediates:
            out.update({'intermediates': intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False):
        """Noise ``x0`` directly to schedule index ``t``.

        With ``max_denoise`` the noise is added with multiplier 1.0 instead of
        the schedule's ``sqrt(1 - alpha)`` value.
        """
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        if max_denoise:
            noise_multiplier = 1.0
        else:
            noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)

        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise)

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False, callback=None):
        """Denoise ``x_latent`` over the first ``t_start`` schedule entries, visited in reverse.

        Returns:
            The tensor produced by the last :meth:`p_sample_ddim` call
            (``x_latent`` itself when no step is taken).
        """
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        # ``self.ddim_timesteps`` is a torch tensor, which ``np.flip`` cannot
        # reverse (negative-step indexing); normalise to a tensor and flip it.
        timesteps = torch.as_tensor(timesteps)[:t_start]

        time_range = timesteps.flip(0)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), int(step), device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
            if callback:
                callback(i)
        return x_dec
|
|
||||||
@ -1 +0,0 @@
|
|||||||
from .sampler import DPMSolverSampler
|
|
||||||
File diff suppressed because it is too large
Load Diff
@ -1,96 +0,0 @@
|
|||||||
"""SAMPLING ONLY."""
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
|
||||||
|
|
||||||
# Maps the wrapped model's ``parameterization`` name to the ``model_type``
# string handed to ``model_wrapper``.
MODEL_TYPES = {
    "eps": "noise",
    "v": "v"
}
|
|
||||||
|
|
||||||
|
|
||||||
class DPMSolverSampler(object):
    """Sampler that runs the multistep DPM-Solver against a wrapped model."""

    def __init__(self, model, device=torch.device("cuda"), **kwargs):
        """Keep the model and a float32 copy of its ``alphas_cumprod``."""
        super().__init__()
        self.model = model
        self.device = device
        schedule = model.alphas_cumprod.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer('alphas_cumprod', schedule)

    def register_buffer(self, name, attr):
        """Set ``self.<name> = attr``, moving plain tensors to ``self.device`` first."""
        needs_move = type(attr) == torch.Tensor and attr.device != self.device
        if needs_move:
            attr = attr.to(self.device)
        setattr(self, name, attr)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Sample ``batch_size`` items of shape ``shape`` in ``S`` solver steps.

        Prints a warning when the conditioning's leading dimension differs
        from ``batch_size``. Starts from ``x_T`` when given, otherwise from
        Gaussian noise.

        Returns:
            ``(x, None)`` with ``x`` on the device of ``self.model.betas``.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                first = conditioning[list(conditioning)[0]]
                # Descend through nested lists to the first leaf entry.
                while isinstance(first, list):
                    first = first[0]
                if isinstance(first, torch.Tensor):
                    cbs = first.shape[0]
                    if cbs != batch_size:
                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            elif isinstance(conditioning, list):
                for entry in conditioning:
                    if entry.shape[0] != batch_size:
                        print(f"Warning: Got {entry.shape[0]} conditionings but batch-size is {batch_size}")
            elif isinstance(conditioning, torch.Tensor):
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')

        device = self.model.betas.device
        start = x_T if x_T is not None else torch.randn(size, device=device)

        noise_schedule = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

        wrapped_fn = model_wrapper(
            lambda x, t, c: self.model.apply_model(x, t, c),
            noise_schedule,
            model_type=MODEL_TYPES[self.model.parameterization],
            guidance_type="classifier-free",
            condition=conditioning,
            unconditional_condition=unconditional_conditioning,
            guidance_scale=unconditional_guidance_scale,
        )

        solver = DPM_Solver(wrapped_fn, noise_schedule, predict_x0=True, thresholding=False)
        result = solver.sample(start, steps=S, skip_type="time_uniform", method="multistep", order=2,
                               lower_order_final=True)

        return result.to(device), None
|
|
||||||
@ -1,245 +0,0 @@
|
|||||||
"""SAMPLING ONLY."""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
|
||||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
|
||||||
|
|
||||||
|
|
||||||
class PLMSSampler(object):
|
|
||||||
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
self.model = model
|
|
||||||
self.ddpm_num_timesteps = model.num_timesteps
|
|
||||||
self.schedule = schedule
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
|
||||||
if type(attr) == torch.Tensor:
|
|
||||||
if attr.device != self.device:
|
|
||||||
attr = attr.to(self.device)
|
|
||||||
setattr(self, name, attr)
|
|
||||||
|
|
||||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
|
||||||
if ddim_eta != 0:
|
|
||||||
raise ValueError('ddim_eta must be 0 for PLMS')
|
|
||||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
|
||||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
|
||||||
alphas_cumprod = self.model.alphas_cumprod
|
|
||||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
|
||||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
|
||||||
|
|
||||||
self.register_buffer('betas', to_torch(self.model.betas))
|
|
||||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
|
||||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
|
||||||
|
|
||||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
|
||||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
|
||||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
|
||||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
|
||||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
|
||||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
|
||||||
|
|
||||||
# ddim sampling parameters
|
|
||||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
|
||||||
ddim_timesteps=self.ddim_timesteps,
|
|
||||||
eta=ddim_eta,verbose=verbose)
|
|
||||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
|
||||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
|
||||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
|
||||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
|
||||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
|
||||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
|
||||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
|
||||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def sample(self,
|
|
||||||
S,
|
|
||||||
batch_size,
|
|
||||||
shape,
|
|
||||||
conditioning=None,
|
|
||||||
callback=None,
|
|
||||||
normals_sequence=None,
|
|
||||||
img_callback=None,
|
|
||||||
quantize_x0=False,
|
|
||||||
eta=0.,
|
|
||||||
mask=None,
|
|
||||||
x0=None,
|
|
||||||
temperature=1.,
|
|
||||||
noise_dropout=0.,
|
|
||||||
score_corrector=None,
|
|
||||||
corrector_kwargs=None,
|
|
||||||
verbose=True,
|
|
||||||
x_T=None,
|
|
||||||
log_every_t=100,
|
|
||||||
unconditional_guidance_scale=1.,
|
|
||||||
unconditional_conditioning=None,
|
|
||||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
|
||||||
dynamic_threshold=None,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
if conditioning is not None:
|
|
||||||
if isinstance(conditioning, dict):
|
|
||||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
|
||||||
if cbs != batch_size:
|
|
||||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
|
||||||
else:
|
|
||||||
if conditioning.shape[0] != batch_size:
|
|
||||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
|
||||||
|
|
||||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
|
||||||
# sampling
|
|
||||||
C, H, W = shape
|
|
||||||
size = (batch_size, C, H, W)
|
|
||||||
print(f'Data shape for PLMS sampling is {size}')
|
|
||||||
|
|
||||||
samples, intermediates = self.plms_sampling(conditioning, size,
|
|
||||||
callback=callback,
|
|
||||||
img_callback=img_callback,
|
|
||||||
quantize_denoised=quantize_x0,
|
|
||||||
mask=mask, x0=x0,
|
|
||||||
ddim_use_original_steps=False,
|
|
||||||
noise_dropout=noise_dropout,
|
|
||||||
temperature=temperature,
|
|
||||||
score_corrector=score_corrector,
|
|
||||||
corrector_kwargs=corrector_kwargs,
|
|
||||||
x_T=x_T,
|
|
||||||
log_every_t=log_every_t,
|
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
|
||||||
unconditional_conditioning=unconditional_conditioning,
|
|
||||||
dynamic_threshold=dynamic_threshold,
|
|
||||||
)
|
|
||||||
return samples, intermediates
|
|
||||||
|
|
||||||
@torch.no_grad()
def plms_sampling(self, cond, shape,
                  x_T=None, ddim_use_original_steps=False,
                  callback=None, timesteps=None, quantize_denoised=False,
                  mask=None, x0=None, img_callback=None, log_every_t=100,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None,
                  dynamic_threshold=None):
    """Run the PLMS reverse loop over the schedule's timesteps in descending order.

    Starts from ``x_T`` when given, otherwise from Gaussian noise of ``shape``.
    When ``mask`` is set, each step blends the image with a noised copy of
    ``x0`` obtained from ``self.model.q_sample``.

    Returns:
        ``(img, intermediates)`` where ``intermediates`` holds the
        ``'x_inter'`` and ``'pred_x0'`` lists.
    """
    device = self.model.betas.device
    batch = shape[0]
    img = x_T if x_T is not None else torch.randn(shape, device=device)

    if timesteps is None:
        timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
    elif not ddim_use_original_steps:
        subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
        timesteps = self.ddim_timesteps[:subset_end]

    intermediates = {'x_inter': [img], 'pred_x0': [img]}
    if ddim_use_original_steps:
        time_range = list(reversed(range(0, timesteps)))
        total_steps = timesteps
    else:
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
    print(f"Running PLMS Sampling with {total_steps} timesteps")

    progress = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
    # History of earlier model outputs handed to p_sample_plms; capped below.
    old_eps = []

    for i, step in enumerate(progress):
        index = total_steps - i - 1
        ts = torch.full((batch,), step, device=device, dtype=torch.long)
        # The following timestep, clamped to the last entry of the range.
        next_pos = min(i + 1, len(time_range) - 1)
        ts_next = torch.full((batch,), time_range[next_pos], device=device, dtype=torch.long)

        if mask is not None:
            assert x0 is not None
            img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
            img = img_orig * mask + (1. - mask) * img

        img, pred_x0, e_t = self.p_sample_plms(
            img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
            quantize_denoised=quantize_denoised, temperature=temperature,
            noise_dropout=noise_dropout, score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            old_eps=old_eps, t_next=ts_next,
            dynamic_threshold=dynamic_threshold)
        old_eps.append(e_t)
        if len(old_eps) >= 4:
            old_eps.pop(0)
        if callback:
            callback(i)
        if img_callback:
            img_callback(pred_x0, i)

        if index % log_every_t == 0 or index == total_steps - 1:
            intermediates['x_inter'].append(img)
            intermediates['pred_x0'].append(pred_x0)

    return img, intermediates
|
|
||||||
|
|
||||||
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
                  dynamic_threshold=None):
    """Run a single PLMS update and return ``(x_prev, pred_x0, e_t)``.

    The model output for ``(x, t)`` is combined with up to three previous
    outputs from ``old_eps`` using Adams-Bashforth style coefficients; when
    ``old_eps`` is empty a second model call at ``t_next`` is averaged in
    instead.

    :param x: current sample; ``x.shape[0]`` is used as the batch size.
        NOTE(review): the schedule values are broadcast as (b, 1, 1, 1), so x
        is presumably 4-D -- confirm against the caller.
    :param c: conditioning passed to ``self.model.apply_model``.
    :param t: timestep tensor passed to the model for ``x``.
    :param index: position in the alpha/sigma tables selected below.
    :param repeat_noise: forwarded to ``noise_like``.
    :param use_original_steps: pick the tables stored on ``self.model``
        instead of the ``self.ddim_*`` ones.
    :param quantize_denoised: pass ``pred_x0`` through
        ``self.model.first_stage_model.quantize``.
    :param temperature: multiplier applied to the sampled noise.
    :param noise_dropout: dropout probability applied to the noise when > 0.
    :param score_corrector: optional object whose ``modify_score`` adjusts the
        model output (only allowed for the "eps" parameterization).
    :param corrector_kwargs: extra keyword arguments for ``modify_score``;
        ``None`` is treated as no extra arguments.
    :param unconditional_guidance_scale: guidance weight; 1.0 disables the
        unconditional pass.
    :param unconditional_conditioning: conditioning for the unconditional
        pass; ``None`` disables it.
    :param old_eps: previous model outputs, oldest first; ``None`` is treated
        as an empty history. This method only reads it.
    :param t_next: timestep used for the extra model call on the first step.
    :param dynamic_threshold: if not ``None``, value forwarded to
        ``norm_thresholding`` for ``pred_x0``.
    :returns: tuple ``(x_prev, pred_x0, e_t)`` where ``e_t`` is the raw model
        output for ``(x, t)`` (not the multistep combination).
    """
    # The signature defaults are None, but the code below calls len(old_eps)
    # and unpacks **corrector_kwargs, so normalise them first.
    if old_eps is None:
        old_eps = []
    if corrector_kwargs is None:
        corrector_kwargs = {}

    b, *_, device = *x.shape, x.device

    def get_model_output(x, t):
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            # Evaluate the unconditional and conditional branches in one batch.
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        return e_t

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

    def get_x_prev_and_pred_x0(e_t, index):
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        if dynamic_threshold is not None:
            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    e_t = get_model_output(x, t)
    if len(old_eps) == 0:
        # Pseudo Improved Euler (2nd order)
        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
        e_t_next = get_model_output(x_prev, t_next)
        e_t_prime = (e_t + e_t_next) / 2
    elif len(old_eps) == 1:
        # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (3 * e_t - old_eps[-1]) / 2
    elif len(old_eps) == 2:
        # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
        e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    else:
        # 4th order Pseudo Linear Multistep (Adams-Bashforth); only the three
        # most recent entries of old_eps are used.
        e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

    x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

    return x_prev, pred_x0, e_t
|
|
||||||
@ -1,22 +0,0 @@
|
|||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions.

    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py

    :param x: object exposing ``ndim`` and supporting ``...``/``None`` indexing.
    :param target_dims: desired number of dimensions of the result.
    :returns: ``x`` indexed with trailing ``None`` entries, one per missing dim.
    :raises ValueError: if ``x`` already has more than ``target_dims`` dims.
    """
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    # (..., None, None, ...) keeps every existing axis and adds size-1 axes at the end.
    return x[(...,) + (None,) * dims_to_append]
|
|
||||||
|
|
||||||
|
|
||||||
def norm_thresholding(x0, value):
    """Scale each batch entry of ``x0`` by ``value / max(rms, value)``.

    ``rms`` is the root-mean-square of the entry taken over every dimension
    after the first. Entries whose RMS is at most ``value`` get a factor of 1;
    the others are scaled down.

    :param x0: tensor with the batch on dimension 0.
    :param value: threshold, used both as the clamp minimum and the numerator.
    :returns: the rescaled tensor, same shape as ``x0``.
    """
    # Per-entry RMS, floored at `value`, then reshaped so it broadcasts over x0.
    s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
    return x0 * (value / s)
|
|
||||||
|
|
||||||
|
|
||||||
def spatial_norm_thresholding(x0, value):
    """Scale ``x0`` by ``value / max(rms, value)`` computed along dimension 1.

    The root-mean-square is taken over dimension 1 only (kept as a size-1
    axis), so every remaining position gets its own scale factor.

    :param x0: input tensor; the original comment labels the layout "b c h w".
    :param value: threshold, used both as the clamp minimum and the numerator.
    :returns: the rescaled tensor, same shape as ``x0``.
    """
    # b c h w
    s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
    return x0 * (value / s)
|
|
||||||
@ -160,32 +160,19 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
|||||||
|
|
||||||
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
|
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
|
||||||
|
|
||||||
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
|
|
||||||
|
|
||||||
kv_chunk_size_min = None
|
kv_chunk_size_min = None
|
||||||
|
kv_chunk_size = None
|
||||||
|
query_chunk_size = None
|
||||||
|
|
||||||
#not sure at all about the math here
|
for x in [4096, 2048, 1024, 512, 256]:
|
||||||
#TODO: tweak this
|
count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
|
||||||
if mem_free_total > 8192 * 1024 * 1024 * 1.3:
|
if count >= k_tokens:
|
||||||
query_chunk_size_x = 1024 * 4
|
kv_chunk_size = k_tokens
|
||||||
elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
|
query_chunk_size = x
|
||||||
query_chunk_size_x = 1024 * 2
|
break
|
||||||
else:
|
|
||||||
query_chunk_size_x = 1024
|
|
||||||
kv_chunk_size_min_x = None
|
|
||||||
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
|
|
||||||
if kv_chunk_size_x < 1024:
|
|
||||||
kv_chunk_size_x = None
|
|
||||||
|
|
||||||
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
if query_chunk_size is None:
|
||||||
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
query_chunk_size = 512
|
||||||
# i.e. send it down the unchunked fast-path
|
|
||||||
query_chunk_size = q_tokens
|
|
||||||
kv_chunk_size = k_tokens
|
|
||||||
else:
|
|
||||||
query_chunk_size = query_chunk_size_x
|
|
||||||
kv_chunk_size = kv_chunk_size_x
|
|
||||||
kv_chunk_size_min = kv_chunk_size_min_x
|
|
||||||
|
|
||||||
hidden_states = efficient_dot_product_attention(
|
hidden_states = efficient_dot_product_attention(
|
||||||
query,
|
query,
|
||||||
@ -222,9 +209,14 @@ def attention_split(q, k, v, heads, mask=None):
|
|||||||
|
|
||||||
mem_free_total = model_management.get_free_memory(q.device)
|
mem_free_total = model_management.get_free_memory(q.device)
|
||||||
|
|
||||||
|
if _ATTN_PRECISION =="fp32":
|
||||||
|
element_size = 4
|
||||||
|
else:
|
||||||
|
element_size = q.element_size()
|
||||||
|
|
||||||
gb = 1024 ** 3
|
gb = 1024 ** 3
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
|
||||||
modifier = 3 if q.element_size() == 2 else 2.5
|
modifier = 3
|
||||||
mem_required = tensor_size * modifier
|
mem_required = tensor_size * modifier
|
||||||
steps = 1
|
steps = 1
|
||||||
|
|
||||||
@ -252,10 +244,10 @@ def attention_split(q, k, v, heads, mask=None):
|
|||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
|
||||||
else:
|
else:
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
|
||||||
first_op_done = True
|
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||||
del s1
|
del s1
|
||||||
|
first_op_done = True
|
||||||
|
|
||||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||||
del s2
|
del s2
|
||||||
|
|||||||
@ -259,10 +259,6 @@ class UNetModel(nn.Module):
|
|||||||
:param model_channels: base channel count for the model.
|
:param model_channels: base channel count for the model.
|
||||||
:param out_channels: channels in the output Tensor.
|
:param out_channels: channels in the output Tensor.
|
||||||
:param num_res_blocks: number of residual blocks per downsample.
|
:param num_res_blocks: number of residual blocks per downsample.
|
||||||
:param attention_resolutions: a collection of downsample rates at which
|
|
||||||
attention will take place. May be a set, list, or tuple.
|
|
||||||
For example, if this contains 4, then at 4x downsampling, attention
|
|
||||||
will be used.
|
|
||||||
:param dropout: the dropout probability.
|
:param dropout: the dropout probability.
|
||||||
:param channel_mult: channel multiplier for each level of the UNet.
|
:param channel_mult: channel multiplier for each level of the UNet.
|
||||||
:param conv_resample: if True, use learned convolutions for upsampling and
|
:param conv_resample: if True, use learned convolutions for upsampling and
|
||||||
@ -289,7 +285,6 @@ class UNetModel(nn.Module):
|
|||||||
model_channels,
|
model_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
num_res_blocks,
|
num_res_blocks,
|
||||||
attention_resolutions,
|
|
||||||
dropout=0,
|
dropout=0,
|
||||||
channel_mult=(1, 2, 4, 8),
|
channel_mult=(1, 2, 4, 8),
|
||||||
conv_resample=True,
|
conv_resample=True,
|
||||||
@ -314,6 +309,7 @@ class UNetModel(nn.Module):
|
|||||||
use_linear_in_transformer=False,
|
use_linear_in_transformer=False,
|
||||||
adm_in_channels=None,
|
adm_in_channels=None,
|
||||||
transformer_depth_middle=None,
|
transformer_depth_middle=None,
|
||||||
|
transformer_depth_output=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=comfy.ops,
|
operations=comfy.ops,
|
||||||
):
|
):
|
||||||
@ -341,10 +337,7 @@ class UNetModel(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.model_channels = model_channels
|
self.model_channels = model_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
if isinstance(transformer_depth, int):
|
|
||||||
transformer_depth = len(channel_mult) * [transformer_depth]
|
|
||||||
if transformer_depth_middle is None:
|
|
||||||
transformer_depth_middle = transformer_depth[-1]
|
|
||||||
if isinstance(num_res_blocks, int):
|
if isinstance(num_res_blocks, int):
|
||||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||||
else:
|
else:
|
||||||
@ -352,18 +345,16 @@ class UNetModel(nn.Module):
|
|||||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||||
self.num_res_blocks = num_res_blocks
|
self.num_res_blocks = num_res_blocks
|
||||||
|
|
||||||
if disable_self_attentions is not None:
|
if disable_self_attentions is not None:
|
||||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||||
assert len(disable_self_attentions) == len(channel_mult)
|
assert len(disable_self_attentions) == len(channel_mult)
|
||||||
if num_attention_blocks is not None:
|
if num_attention_blocks is not None:
|
||||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
|
||||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
|
||||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
|
||||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
|
||||||
f"attention will still not be set.")
|
|
||||||
|
|
||||||
self.attention_resolutions = attention_resolutions
|
transformer_depth = transformer_depth[:]
|
||||||
|
transformer_depth_output = transformer_depth_output[:]
|
||||||
|
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.channel_mult = channel_mult
|
self.channel_mult = channel_mult
|
||||||
self.conv_resample = conv_resample
|
self.conv_resample = conv_resample
|
||||||
@ -428,7 +419,8 @@ class UNetModel(nn.Module):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
ch = mult * model_channels
|
ch = mult * model_channels
|
||||||
if ds in attention_resolutions:
|
num_transformers = transformer_depth.pop(0)
|
||||||
|
if num_transformers > 0:
|
||||||
if num_head_channels == -1:
|
if num_head_channels == -1:
|
||||||
dim_head = ch // num_heads
|
dim_head = ch // num_heads
|
||||||
else:
|
else:
|
||||||
@ -444,7 +436,7 @@ class UNetModel(nn.Module):
|
|||||||
|
|
||||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||||
layers.append(SpatialTransformer(
|
layers.append(SpatialTransformer(
|
||||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
@ -488,7 +480,7 @@ class UNetModel(nn.Module):
|
|||||||
if legacy:
|
if legacy:
|
||||||
#num_heads = 1
|
#num_heads = 1
|
||||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
self.middle_block = TimestepEmbedSequential(
|
mid_block = [
|
||||||
ResBlock(
|
ResBlock(
|
||||||
ch,
|
ch,
|
||||||
time_embed_dim,
|
time_embed_dim,
|
||||||
@ -499,8 +491,9 @@ class UNetModel(nn.Module):
|
|||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations
|
operations=operations
|
||||||
),
|
)]
|
||||||
SpatialTransformer( # always uses a self-attn
|
if transformer_depth_middle >= 0:
|
||||||
|
mid_block += [SpatialTransformer( # always uses a self-attn
|
||||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||||
@ -515,8 +508,8 @@ class UNetModel(nn.Module):
|
|||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations
|
operations=operations
|
||||||
),
|
)]
|
||||||
)
|
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
self.output_blocks = nn.ModuleList([])
|
self.output_blocks = nn.ModuleList([])
|
||||||
@ -538,7 +531,8 @@ class UNetModel(nn.Module):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
ch = model_channels * mult
|
ch = model_channels * mult
|
||||||
if ds in attention_resolutions:
|
num_transformers = transformer_depth_output.pop()
|
||||||
|
if num_transformers > 0:
|
||||||
if num_head_channels == -1:
|
if num_head_channels == -1:
|
||||||
dim_head = ch // num_heads
|
dim_head = ch // num_heads
|
||||||
else:
|
else:
|
||||||
@ -555,7 +549,7 @@ class UNetModel(nn.Module):
|
|||||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||||
layers.append(
|
layers.append(
|
||||||
SpatialTransformer(
|
SpatialTransformer(
|
||||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
|
|||||||
@ -83,7 +83,8 @@ def _summarize_chunk(
|
|||||||
)
|
)
|
||||||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||||
max_score = max_score.detach()
|
max_score = max_score.detach()
|
||||||
torch.exp(attn_weights - max_score, out=attn_weights)
|
attn_weights -= max_score
|
||||||
|
torch.exp(attn_weights, out=attn_weights)
|
||||||
exp_weights = attn_weights.to(value.dtype)
|
exp_weights = attn_weights.to(value.dtype)
|
||||||
exp_values = torch.bmm(exp_weights, value)
|
exp_values = torch.bmm(exp_weights, value)
|
||||||
max_score = max_score.squeeze(-1)
|
max_score = max_score.squeeze(-1)
|
||||||
|
|||||||
@ -141,9 +141,9 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
|
|
||||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||||
clip_l_present = False
|
clip_l_present = False
|
||||||
for b in range(32):
|
for b in range(32): #TODO: clean up
|
||||||
for c in LORA_CLIP_MAP:
|
for c in LORA_CLIP_MAP:
|
||||||
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
@ -154,6 +154,8 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
|
|
||||||
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
|
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||||
|
key_map[lora_key] = k
|
||||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
clip_l_present = True
|
clip_l_present = True
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
|
|||||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy.conds
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from . import utils
|
from . import utils
|
||||||
@ -12,6 +13,96 @@ class ModelType(Enum):
|
|||||||
EPS = 1
|
EPS = 1
|
||||||
V_PREDICTION = 2
|
V_PREDICTION = 2
|
||||||
|
|
||||||
|
|
||||||
|
#NOTE: all this sampling stuff will be moved
|
||||||
|
class EPS:
    """Input scaling and denoised-output formula for one model type.

    Both methods read ``self.sigma_data``; this class does not set it, so it
    must be provided by whatever class it is combined with.
    """

    def calculate_input(self, sigma, noise):
        """Return ``noise / sqrt(sigma**2 + sigma_data**2)``.

        ``sigma`` is reshaped to ``(sigma.shape[0], 1, ..., 1)`` so that it
        broadcasts against ``noise``.
        """
        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
        return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5

    def calculate_denoised(self, sigma, model_output, model_input):
        """Return ``model_input - model_output * sigma``.

        ``sigma`` is reshaped to broadcast against ``model_output``.
        """
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma
|
||||||
|
|
||||||
|
|
||||||
|
class V_PREDICTION(EPS):
    """Variant of ``EPS`` that overrides only the denoised-output formula."""

    def calculate_denoised(self, sigma, model_output, model_input):
        """Combine ``model_input`` and ``model_output`` into the denoised result.

        With ``d = sigma_data`` and ``s = sigma`` (reshaped to broadcast
        against ``model_output``) this returns
        ``model_input * d**2 / (s**2 + d**2) - model_output * s * d / sqrt(s**2 + d**2)``.
        """
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSamplingDiscrete(torch.nn.Module):
    """Discrete noise schedule exposed as a table of sigmas.

    The schedule is stored in two buffers, ``sigmas`` and ``log_sigmas``, and
    the methods convert between a sigma value and a (possibly fractional)
    index into that table.
    """

    def __init__(self, model_config=None):
        """Build the schedule.

        :param model_config: optional object whose ``beta_schedule`` attribute
            names the schedule; "linear" is used when it is ``None``.
        """
        super().__init__()
        beta_schedule = "linear"
        if model_config is not None:
            beta_schedule = model_config.beta_schedule
        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
        self.sigma_data = 1.0

    def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Compute sigmas from betas and register them as buffers.

        :param given_betas: explicit 1-D array of betas; when ``None`` they
            are produced by ``make_beta_schedule`` from the other arguments.
        """
        if given_betas is not None:
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)

        # The effective number of steps comes from the betas actually used.
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end

        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())

    @property
    def sigma_min(self):
        """First entry of the sigma table."""
        return self.sigmas[0]

    @property
    def sigma_max(self):
        """Last entry of the sigma table."""
        return self.sigmas[-1]

    def timestep(self, sigma):
        """Return, for each sigma, the index of the nearest table entry in log space."""
        log_sigma = sigma.log()
        # Rows index table entries, columns index the queried sigmas.
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def sigma(self, timestep):
        """Return the sigma for a (possibly fractional) table index.

        The index is clamped to the table range and the result is linearly
        interpolated between neighbouring entries in log-sigma space.
        """
        t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        w = t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()

    def percent_to_sigma(self, percent):
        """Map ``percent`` (0.0 -> first entry, 1.0 -> last entry) to a sigma.

        The last index is derived from ``num_timesteps`` rather than being
        hard-coded, which is the same value (999) for a 1000-step schedule.
        """
        return self.sigma(torch.tensor(percent * float(self.num_timesteps - 1)))
|
||||||
|
|
||||||
|
def model_sampling(model_config, model_type):
    """Create the sampling helper object for ``model_type``.

    A class deriving from ``ModelSamplingDiscrete`` and the formula class
    matching ``model_type`` is defined on the fly and instantiated with
    ``model_config``.

    :param model_config: forwarded to the constructor of the combined class.
    :param model_type: ``ModelType.EPS`` or ``ModelType.V_PREDICTION``.
    :returns: an instance of the combined class.
    :raises ValueError: if ``model_type`` is neither of the supported values
        (previously this surfaced as an unbound-local error).
    """
    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION
    else:
        raise ValueError("unsupported model_type: {}".format(model_type))

    s = ModelSamplingDiscrete

    # Schedule class first so its __init__ runs; the formula class only adds methods.
    class ModelSampling(s, c):
        pass

    return ModelSampling(model_config)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(torch.nn.Module):
|
class BaseModel(torch.nn.Module):
|
||||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -19,10 +110,12 @@ class BaseModel(torch.nn.Module):
|
|||||||
unet_config = model_config.unet_config
|
unet_config = model_config.unet_config
|
||||||
self.latent_format = model_config.latent_format
|
self.latent_format = model_config.latent_format
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
|
||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
self.diffusion_model = UNetModel(**unet_config, device=device)
|
self.diffusion_model = UNetModel(**unet_config, device=device)
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
|
self.model_sampling = model_sampling(model_config, model_type)
|
||||||
|
|
||||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||||
if self.adm_channels is None:
|
if self.adm_channels is None:
|
||||||
self.adm_channels = 0
|
self.adm_channels = 0
|
||||||
@ -30,38 +123,22 @@ class BaseModel(torch.nn.Module):
|
|||||||
print("model_type", model_type.name)
|
print("model_type", model_type.name)
|
||||||
print("adm", self.adm_channels)
|
print("adm", self.adm_channels)
|
||||||
|
|
||||||
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
sigma = t
|
||||||
if given_betas is not None:
|
xc = self.model_sampling.calculate_input(sigma, x)
|
||||||
betas = given_betas
|
|
||||||
else:
|
|
||||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
|
||||||
alphas = 1. - betas
|
|
||||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
|
||||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
|
||||||
|
|
||||||
timesteps, = betas.shape
|
|
||||||
self.num_timesteps = int(timesteps)
|
|
||||||
self.linear_start = linear_start
|
|
||||||
self.linear_end = linear_end
|
|
||||||
|
|
||||||
self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
|
||||||
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
|
||||||
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
|
||||||
|
|
||||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
|
|
||||||
if c_concat is not None:
|
if c_concat is not None:
|
||||||
xc = torch.cat([x] + [c_concat], dim=1)
|
xc = torch.cat([xc] + [c_concat], dim=1)
|
||||||
else:
|
|
||||||
xc = x
|
|
||||||
context = c_crossattn
|
context = c_crossattn
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype()
|
||||||
xc = xc.to(dtype)
|
xc = xc.to(dtype)
|
||||||
t = t.to(dtype)
|
t = self.model_sampling.timestep(t).float()
|
||||||
context = context.to(dtype)
|
context = context.to(dtype)
|
||||||
if c_adm is not None:
|
extra_conds = {}
|
||||||
c_adm = c_adm.to(dtype)
|
for o in kwargs:
|
||||||
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
|
extra_conds[o] = kwargs[o].to(dtype)
|
||||||
|
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||||
|
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||||
|
|
||||||
def get_dtype(self):
|
def get_dtype(self):
|
||||||
return self.diffusion_model.dtype
|
return self.diffusion_model.dtype
|
||||||
@ -72,7 +149,8 @@ class BaseModel(torch.nn.Module):
|
|||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def cond_concat(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
|
out = {}
|
||||||
if self.inpaint_model:
|
if self.inpaint_model:
|
||||||
concat_keys = ("mask", "masked_image")
|
concat_keys = ("mask", "masked_image")
|
||||||
cond_concat = []
|
cond_concat = []
|
||||||
@ -101,8 +179,12 @@ class BaseModel(torch.nn.Module):
|
|||||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||||
elif ck == "masked_image":
|
elif ck == "masked_image":
|
||||||
cond_concat.append(blank_inpaint_image_like(noise))
|
cond_concat.append(blank_inpaint_image_like(noise))
|
||||||
return cond_concat
|
data = torch.cat(cond_concat, dim=1)
|
||||||
return None
|
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
||||||
|
adm = self.encode_adm(**kwargs)
|
||||||
|
if adm is not None:
|
||||||
|
out['y'] = comfy.conds.CONDRegular(adm)
|
||||||
|
return out
|
||||||
|
|
||||||
def load_model_weights(self, sd, unet_prefix=""):
|
def load_model_weights(self, sd, unet_prefix=""):
|
||||||
to_load = {}
|
to_load = {}
|
||||||
|
|||||||
@ -14,6 +14,19 @@ def count_blocks(state_dict_keys, prefix_string):
|
|||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
    """Inspect the transformer stored under ``prefix + "1."`` in a state dict.

    :param prefix: key prefix of the block to inspect (ending in ".").
    :param state_dict_keys: iterable of all key names in ``state_dict``.
    :param state_dict: mapping from key name to an object with a ``shape``.
    :returns: ``None`` when no key starts with
        ``prefix + "1.transformer_blocks."``; otherwise a tuple
        ``(depth, context_dim, use_linear_in_transformer)`` where ``depth`` is
        ``count_blocks`` over the transformer blocks, ``context_dim`` is
        ``shape[1]`` of the first block's ``attn2.to_k.weight`` and
        ``use_linear_in_transformer`` is True when ``1.proj_in.weight`` has
        exactly two dimensions.
    """
    transformer_prefix = prefix + "1.transformer_blocks."
    # Only existence matters here, so stop at the first matching key.
    if not any(k.startswith(transformer_prefix) for k in state_dict_keys):
        return None

    last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
    context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
    use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
    return last_transformer_depth, context_dim, use_linear_in_transformer
|
||||||
|
|
||||||
def detect_unet_config(state_dict, key_prefix, dtype):
|
def detect_unet_config(state_dict, key_prefix, dtype):
|
||||||
state_dict_keys = list(state_dict.keys())
|
state_dict_keys = list(state_dict.keys())
|
||||||
|
|
||||||
@ -40,6 +53,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
|||||||
channel_mult = []
|
channel_mult = []
|
||||||
attention_resolutions = []
|
attention_resolutions = []
|
||||||
transformer_depth = []
|
transformer_depth = []
|
||||||
|
transformer_depth_output = []
|
||||||
context_dim = None
|
context_dim = None
|
||||||
use_linear_in_transformer = False
|
use_linear_in_transformer = False
|
||||||
|
|
||||||
@ -48,60 +62,67 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
|||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
last_res_blocks = 0
|
last_res_blocks = 0
|
||||||
last_transformer_depth = 0
|
|
||||||
last_channel_mult = 0
|
last_channel_mult = 0
|
||||||
|
|
||||||
while True:
|
input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.')
|
||||||
|
for count in range(input_block_count):
|
||||||
prefix = '{}input_blocks.{}.'.format(key_prefix, count)
|
prefix = '{}input_blocks.{}.'.format(key_prefix, count)
|
||||||
|
prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1)
|
||||||
|
|
||||||
block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
|
block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
|
||||||
if len(block_keys) == 0:
|
if len(block_keys) == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys)))
|
||||||
|
|
||||||
if "{}0.op.weight".format(prefix) in block_keys: #new layer
|
if "{}0.op.weight".format(prefix) in block_keys: #new layer
|
||||||
if last_transformer_depth > 0:
|
|
||||||
attention_resolutions.append(current_res)
|
|
||||||
transformer_depth.append(last_transformer_depth)
|
|
||||||
num_res_blocks.append(last_res_blocks)
|
num_res_blocks.append(last_res_blocks)
|
||||||
channel_mult.append(last_channel_mult)
|
channel_mult.append(last_channel_mult)
|
||||||
|
|
||||||
current_res *= 2
|
current_res *= 2
|
||||||
last_res_blocks = 0
|
last_res_blocks = 0
|
||||||
last_transformer_depth = 0
|
|
||||||
last_channel_mult = 0
|
last_channel_mult = 0
|
||||||
|
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
|
||||||
|
if out is not None:
|
||||||
|
transformer_depth_output.append(out[0])
|
||||||
|
else:
|
||||||
|
transformer_depth_output.append(0)
|
||||||
else:
|
else:
|
||||||
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
|
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
|
||||||
if res_block_prefix in block_keys:
|
if res_block_prefix in block_keys:
|
||||||
last_res_blocks += 1
|
last_res_blocks += 1
|
||||||
last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
|
last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
|
||||||
|
|
||||||
transformer_prefix = prefix + "1.transformer_blocks."
|
out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
|
||||||
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
|
if out is not None:
|
||||||
if len(transformer_keys) > 0:
|
transformer_depth.append(out[0])
|
||||||
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
|
if context_dim is None:
|
||||||
if context_dim is None:
|
context_dim = out[1]
|
||||||
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
|
use_linear_in_transformer = out[2]
|
||||||
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
|
else:
|
||||||
|
transformer_depth.append(0)
|
||||||
|
|
||||||
|
res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
|
||||||
|
if res_block_prefix in block_keys_output:
|
||||||
|
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
|
||||||
|
if out is not None:
|
||||||
|
transformer_depth_output.append(out[0])
|
||||||
|
else:
|
||||||
|
transformer_depth_output.append(0)
|
||||||
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
if last_transformer_depth > 0:
|
|
||||||
attention_resolutions.append(current_res)
|
|
||||||
transformer_depth.append(last_transformer_depth)
|
|
||||||
num_res_blocks.append(last_res_blocks)
|
num_res_blocks.append(last_res_blocks)
|
||||||
channel_mult.append(last_channel_mult)
|
channel_mult.append(last_channel_mult)
|
||||||
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
|
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
|
||||||
|
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
|
||||||
if len(set(num_res_blocks)) == 1:
|
else:
|
||||||
num_res_blocks = num_res_blocks[0]
|
transformer_depth_middle = -1
|
||||||
|
|
||||||
if len(set(transformer_depth)) == 1:
|
|
||||||
transformer_depth = transformer_depth[0]
|
|
||||||
|
|
||||||
unet_config["in_channels"] = in_channels
|
unet_config["in_channels"] = in_channels
|
||||||
unet_config["model_channels"] = model_channels
|
unet_config["model_channels"] = model_channels
|
||||||
unet_config["num_res_blocks"] = num_res_blocks
|
unet_config["num_res_blocks"] = num_res_blocks
|
||||||
unet_config["attention_resolutions"] = attention_resolutions
|
|
||||||
unet_config["transformer_depth"] = transformer_depth
|
unet_config["transformer_depth"] = transformer_depth
|
||||||
|
unet_config["transformer_depth_output"] = transformer_depth_output
|
||||||
unet_config["channel_mult"] = channel_mult
|
unet_config["channel_mult"] = channel_mult
|
||||||
unet_config["transformer_depth_middle"] = transformer_depth_middle
|
unet_config["transformer_depth_middle"] = transformer_depth_middle
|
||||||
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
|
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
|
||||||
@ -124,6 +145,45 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma
|
|||||||
else:
|
else:
|
||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
|
def convert_config(unet_config):
    """Convert a per-level UNet config into the per-block layout.

    The input dict is not modified; a shallow copy is returned with:

    * ``num_res_blocks`` expanded from an int to one entry per
      ``channel_mult`` level (lists are passed through unchanged).
    * If ``attention_resolutions`` is present, it is removed and replaced by
      explicit per-block depths:

      - ``transformer_depth``: one entry per input res block,
      - ``transformer_depth_output``: ``res + 1`` entries per level,
      - ``transformer_depth_middle``: the given value, or the last per-level
        ``transformer_depth`` when it was ``None``/absent.

      A level gets its ``transformer_depth`` only when the running scale
      (1, 2, 4, ... doubling per level) is listed in
      ``attention_resolutions``; otherwise its depth is 0.

    Args:
        unet_config: dict using the ``attention_resolutions`` style keys.

    Returns:
        A new dict in the converted layout.
    """
    new_config = unet_config.copy()
    num_res_blocks = new_config.get("num_res_blocks", None)
    channel_mult = new_config.get("channel_mult", None)

    if isinstance(num_res_blocks, int):
        num_res_blocks = len(channel_mult) * [num_res_blocks]

    if "attention_resolutions" in new_config:
        attention_resolutions = new_config.pop("attention_resolutions")
        transformer_depth = new_config.get("transformer_depth", None)
        transformer_depth_middle = new_config.get("transformer_depth_middle", None)

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        if transformer_depth_middle is None:
            transformer_depth_middle = transformer_depth[-1]

        t_in = []
        t_out = []
        scale = 1
        for level, res in enumerate(num_res_blocks):
            # Levels whose scale is not an attention resolution get depth 0.
            depth = transformer_depth[level] if scale in attention_resolutions else 0
            t_in += [depth] * res
            # The output side carries one more block per level than the input side.
            t_out += [depth] * (res + 1)
            scale *= 2

        new_config["transformer_depth"] = t_in
        new_config["transformer_depth_output"] = t_out
        new_config["transformer_depth_middle"] = transformer_depth_middle

    new_config["num_res_blocks"] = num_res_blocks
    return new_config
|
||||||
|
|
||||||
|
|
||||||
def unet_config_from_diffusers_unet(state_dict, dtype):
|
def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||||
match = {}
|
match = {}
|
||||||
attention_resolutions = []
|
attention_resolutions = []
|
||||||
@ -200,7 +260,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
|||||||
matches = False
|
matches = False
|
||||||
break
|
break
|
||||||
if matches:
|
if matches:
|
||||||
return unet_config
|
return convert_config(unet_config)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def model_config_from_diffusers_unet(state_dict, dtype):
|
def model_config_from_diffusers_unet(state_dict, dtype):
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
|
import comfy.conds
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device):
|
|||||||
noise_mask = noise_mask.to(device)
|
noise_mask = noise_mask.to(device)
|
||||||
return noise_mask
|
return noise_mask
|
||||||
|
|
||||||
def broadcast_cond(cond, batch, device):
|
|
||||||
"""broadcasts conditioning to the batch size"""
|
|
||||||
copy = []
|
|
||||||
for p in cond:
|
|
||||||
t = comfy.utils.repeat_to_batch_size(p[0], batch)
|
|
||||||
t = t.to(device)
|
|
||||||
copy += [[t] + p[1:]]
|
|
||||||
return copy
|
|
||||||
|
|
||||||
def get_models_from_cond(cond, model_type):
|
def get_models_from_cond(cond, model_type):
|
||||||
models = []
|
models = []
|
||||||
for c in cond:
|
for c in cond:
|
||||||
if model_type in c[1]:
|
if model_type in c:
|
||||||
models += [c[1][model_type]]
|
models += [c[model_type]]
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
def convert_cond(cond):
    """Convert ``[tensor, options]`` conditioning pairs into flat dicts.

    For every entry ``c`` in ``cond`` a copy of the options dict ``c[1]`` is
    produced. Its ``"model_conds"`` entry always exists in the result; when
    ``c[0]`` is not ``None`` it additionally holds ``"c_crossattn"`` wrapping
    ``c[0]`` in ``comfy.conds.CONDCrossAttn``.

    The caller's data is left untouched: both the options dict and any
    existing ``"model_conds"`` dict inside it are copied before being written.

    Args:
        cond: iterable of ``[cross_attn_or_None, options_dict]`` pairs.

    Returns:
        A list with one dict per input entry.
    """
    out = []
    for c in cond:
        temp = c[1].copy()
        # c[1].copy() is shallow: without this second copy, the write below
        # would land in the caller's own nested "model_conds" dict.
        model_conds = dict(temp.get("model_conds", {}))
        if c[0] is not None:
            model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0])
        temp["model_conds"] = model_conds
        out.append(temp)
    return out
|
||||||
|
|
||||||
def get_additional_models(positive, negative, dtype):
|
def get_additional_models(positive, negative, dtype):
|
||||||
"""loads additional models in positive and negative conditioning"""
|
"""loads additional models in positive and negative conditioning"""
|
||||||
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
|
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
|
||||||
@ -72,6 +75,8 @@ def cleanup_additional_models(models):
|
|||||||
|
|
||||||
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
||||||
device = model.load_device
|
device = model.load_device
|
||||||
|
positive = convert_cond(positive)
|
||||||
|
negative = convert_cond(negative)
|
||||||
|
|
||||||
if noise_mask is not None:
|
if noise_mask is not None:
|
||||||
noise_mask = prepare_mask(noise_mask, noise_shape, device)
|
noise_mask = prepare_mask(noise_mask, noise_shape, device)
|
||||||
@ -81,9 +86,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
|||||||
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
|
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
|
||||||
real_model = model.model
|
real_model = model.model
|
||||||
|
|
||||||
positive_copy = broadcast_cond(positive, noise_shape[0], device)
|
return real_model, positive, negative, noise_mask, models
|
||||||
negative_copy = broadcast_cond(negative, noise_shape[0], device)
|
|
||||||
return real_model, positive_copy, negative_copy, noise_mask, models
|
|
||||||
|
|
||||||
|
|
||||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||||
|
|||||||
@ -1,48 +1,42 @@
|
|||||||
from .k_diffusion import sampling as k_diffusion_sampling
|
from .k_diffusion import sampling as k_diffusion_sampling
|
||||||
from .k_diffusion import external as k_diffusion_external
|
|
||||||
from .extra_samplers import uni_pc
|
from .extra_samplers import uni_pc
|
||||||
import torch
|
import torch
|
||||||
|
import enum
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
|
||||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
|
||||||
import math
|
import math
|
||||||
from comfy import model_base
|
from comfy import model_base
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.conds
|
||||||
|
|
||||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
|
||||||
return abs(a*b) // math.gcd(a, b)
|
|
||||||
|
|
||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns predicted noise
|
#Returns denoised
|
||||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||||
def get_area_and_mult(cond, x_in, timestep_in):
|
def get_area_and_mult(conds, x_in, timestep_in):
|
||||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||||
strength = 1.0
|
strength = 1.0
|
||||||
if 'timestep_start' in cond[1]:
|
|
||||||
timestep_start = cond[1]['timestep_start']
|
if 'timestep_start' in conds:
|
||||||
|
timestep_start = conds['timestep_start']
|
||||||
if timestep_in[0] > timestep_start:
|
if timestep_in[0] > timestep_start:
|
||||||
return None
|
return None
|
||||||
if 'timestep_end' in cond[1]:
|
if 'timestep_end' in conds:
|
||||||
timestep_end = cond[1]['timestep_end']
|
timestep_end = conds['timestep_end']
|
||||||
if timestep_in[0] < timestep_end:
|
if timestep_in[0] < timestep_end:
|
||||||
return None
|
return None
|
||||||
if 'area' in cond[1]:
|
if 'area' in conds:
|
||||||
area = cond[1]['area']
|
area = conds['area']
|
||||||
if 'strength' in cond[1]:
|
if 'strength' in conds:
|
||||||
strength = cond[1]['strength']
|
strength = conds['strength']
|
||||||
|
|
||||||
adm_cond = None
|
|
||||||
if 'adm_encoded' in cond[1]:
|
|
||||||
adm_cond = cond[1]['adm_encoded']
|
|
||||||
|
|
||||||
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||||
if 'mask' in cond[1]:
|
if 'mask' in conds:
|
||||||
# Scale the mask to the size of the input
|
# Scale the mask to the size of the input
|
||||||
# The mask should have been resized as we began the sampling process
|
# The mask should have been resized as we began the sampling process
|
||||||
mask_strength = 1.0
|
mask_strength = 1.0
|
||||||
if "mask_strength" in cond[1]:
|
if "mask_strength" in conds:
|
||||||
mask_strength = cond[1]["mask_strength"]
|
mask_strength = conds["mask_strength"]
|
||||||
mask = cond[1]['mask']
|
mask = conds['mask']
|
||||||
assert(mask.shape[1] == x_in.shape[2])
|
assert(mask.shape[1] == x_in.shape[2])
|
||||||
assert(mask.shape[2] == x_in.shape[3])
|
assert(mask.shape[2] == x_in.shape[3])
|
||||||
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
||||||
@ -51,7 +45,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||||||
mask = torch.ones_like(input_x)
|
mask = torch.ones_like(input_x)
|
||||||
mult = mask * strength
|
mult = mask * strength
|
||||||
|
|
||||||
if 'mask' not in cond[1]:
|
if 'mask' not in conds:
|
||||||
rr = 8
|
rr = 8
|
||||||
if area[2] != 0:
|
if area[2] != 0:
|
||||||
for t in range(rr):
|
for t in range(rr):
|
||||||
@ -67,27 +61,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||||
|
|
||||||
conditionning = {}
|
conditionning = {}
|
||||||
conditionning['c_crossattn'] = cond[0]
|
model_conds = conds["model_conds"]
|
||||||
|
for c in model_conds:
|
||||||
if 'concat' in cond[1]:
|
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||||
cond_concat_in = cond[1]['concat']
|
|
||||||
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
|
||||||
cropped = []
|
|
||||||
for x in cond_concat_in:
|
|
||||||
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
|
||||||
cropped.append(cr)
|
|
||||||
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
|
||||||
|
|
||||||
if adm_cond is not None:
|
|
||||||
conditionning['c_adm'] = adm_cond
|
|
||||||
|
|
||||||
control = None
|
control = None
|
||||||
if 'control' in cond[1]:
|
if 'control' in conds:
|
||||||
control = cond[1]['control']
|
control = conds['control']
|
||||||
|
|
||||||
patches = None
|
patches = None
|
||||||
if 'gligen' in cond[1]:
|
if 'gligen' in conds:
|
||||||
gligen = cond[1]['gligen']
|
gligen = conds['gligen']
|
||||||
patches = {}
|
patches = {}
|
||||||
gligen_type = gligen[0]
|
gligen_type = gligen[0]
|
||||||
gligen_model = gligen[1]
|
gligen_model = gligen[1]
|
||||||
@ -105,22 +89,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||||||
return True
|
return True
|
||||||
if c1.keys() != c2.keys():
|
if c1.keys() != c2.keys():
|
||||||
return False
|
return False
|
||||||
if 'c_crossattn' in c1:
|
for k in c1:
|
||||||
s1 = c1['c_crossattn'].shape
|
if not c1[k].can_concat(c2[k]):
|
||||||
s2 = c2['c_crossattn'].shape
|
|
||||||
if s1 != s2:
|
|
||||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
|
||||||
return False
|
|
||||||
|
|
||||||
mult_min = lcm(s1[1], s2[1])
|
|
||||||
diff = mult_min // min(s1[1], s2[1])
|
|
||||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
|
||||||
return False
|
|
||||||
if 'c_concat' in c1:
|
|
||||||
if c1['c_concat'].shape != c2['c_concat'].shape:
|
|
||||||
return False
|
|
||||||
if 'c_adm' in c1:
|
|
||||||
if c1['c_adm'].shape != c2['c_adm'].shape:
|
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -149,39 +119,27 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||||||
c_concat = []
|
c_concat = []
|
||||||
c_adm = []
|
c_adm = []
|
||||||
crossattn_max_len = 0
|
crossattn_max_len = 0
|
||||||
for x in c_list:
|
|
||||||
if 'c_crossattn' in x:
|
|
||||||
c = x['c_crossattn']
|
|
||||||
if crossattn_max_len == 0:
|
|
||||||
crossattn_max_len = c.shape[1]
|
|
||||||
else:
|
|
||||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
|
||||||
c_crossattn.append(c)
|
|
||||||
if 'c_concat' in x:
|
|
||||||
c_concat.append(x['c_concat'])
|
|
||||||
if 'c_adm' in x:
|
|
||||||
c_adm.append(x['c_adm'])
|
|
||||||
out = {}
|
|
||||||
c_crossattn_out = []
|
|
||||||
for c in c_crossattn:
|
|
||||||
if c.shape[1] < crossattn_max_len:
|
|
||||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
|
||||||
c_crossattn_out.append(c)
|
|
||||||
|
|
||||||
if len(c_crossattn_out) > 0:
|
temp = {}
|
||||||
out['c_crossattn'] = torch.cat(c_crossattn_out)
|
for x in c_list:
|
||||||
if len(c_concat) > 0:
|
for k in x:
|
||||||
out['c_concat'] = torch.cat(c_concat)
|
cur = temp.get(k, [])
|
||||||
if len(c_adm) > 0:
|
cur.append(x[k])
|
||||||
out['c_adm'] = torch.cat(c_adm)
|
temp[k] = cur
|
||||||
|
|
||||||
|
out = {}
|
||||||
|
for k in temp:
|
||||||
|
conds = temp[k]
|
||||||
|
out[k] = conds[0].concat(conds[1:])
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
||||||
out_cond = torch.zeros_like(x_in)
|
out_cond = torch.zeros_like(x_in)
|
||||||
out_count = torch.ones_like(x_in)/100000.0
|
out_count = torch.ones_like(x_in) * 1e-37
|
||||||
|
|
||||||
out_uncond = torch.zeros_like(x_in)
|
out_uncond = torch.zeros_like(x_in)
|
||||||
out_uncond_count = torch.ones_like(x_in)/100000.0
|
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
||||||
|
|
||||||
COND = 0
|
COND = 0
|
||||||
UNCOND = 1
|
UNCOND = 1
|
||||||
@ -281,7 +239,6 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||||||
del out_count
|
del out_count
|
||||||
out_uncond /= out_uncond_count
|
out_uncond /= out_uncond_count
|
||||||
del out_uncond_count
|
del out_uncond_count
|
||||||
|
|
||||||
return out_cond, out_uncond
|
return out_cond, out_uncond
|
||||||
|
|
||||||
|
|
||||||
@ -291,29 +248,20 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
|||||||
|
|
||||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
|
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
|
||||||
if "sampler_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options:
|
||||||
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
|
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x}
|
||||||
return model_options["sampler_cfg_function"](args)
|
return x - model_options["sampler_cfg_function"](args)
|
||||||
else:
|
else:
|
||||||
return uncond + (cond - uncond) * cond_scale
|
return uncond + (cond - uncond) * cond_scale
|
||||||
|
|
||||||
|
|
||||||
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
|
|
||||||
def __init__(self, model, quantize=False, device='cpu'):
|
|
||||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
|
||||||
|
|
||||||
def get_v(self, x, t, cond, **kwargs):
|
|
||||||
return self.inner_model.apply_model(x, t, cond, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class CFGNoisePredictor(torch.nn.Module):
|
class CFGNoisePredictor(torch.nn.Module):
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
self.alphas_cumprod = model.alphas_cumprod
|
|
||||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
||||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||||
return out
|
return out
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.apply_model(*args, **kwargs)
|
||||||
|
|
||||||
class KSamplerX0Inpaint(torch.nn.Module):
|
class KSamplerX0Inpaint(torch.nn.Module):
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
@ -332,32 +280,40 @@ class KSamplerX0Inpaint(torch.nn.Module):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def simple_scheduler(model, steps):
|
def simple_scheduler(model, steps):
|
||||||
|
s = model.model_sampling
|
||||||
sigs = []
|
sigs = []
|
||||||
ss = len(model.sigmas) / steps
|
ss = len(s.sigmas) / steps
|
||||||
for x in range(steps):
|
for x in range(steps):
|
||||||
sigs += [float(model.sigmas[-(1 + int(x * ss))])]
|
sigs += [float(s.sigmas[-(1 + int(x * ss))])]
|
||||||
sigs += [0.0]
|
sigs += [0.0]
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
def ddim_scheduler(model, steps):
|
def ddim_scheduler(model, steps):
|
||||||
|
s = model.model_sampling
|
||||||
sigs = []
|
sigs = []
|
||||||
ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
|
ss = len(s.sigmas) // steps
|
||||||
for x in range(len(ddim_timesteps) - 1, -1, -1):
|
x = 1
|
||||||
ts = ddim_timesteps[x]
|
while x < len(s.sigmas):
|
||||||
if ts > 999:
|
sigs += [float(s.sigmas[x])]
|
||||||
ts = 999
|
x += ss
|
||||||
sigs.append(model.t_to_sigma(torch.tensor(ts)))
|
sigs = sigs[::-1]
|
||||||
sigs += [0.0]
|
sigs += [0.0]
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
def sgm_scheduler(model, steps):
|
def normal_scheduler(model, steps, sgm=False, floor=False):
|
||||||
|
s = model.model_sampling
|
||||||
|
start = s.timestep(s.sigma_max)
|
||||||
|
end = s.timestep(s.sigma_min)
|
||||||
|
|
||||||
|
if sgm:
|
||||||
|
timesteps = torch.linspace(start, end, steps + 1)[:-1]
|
||||||
|
else:
|
||||||
|
timesteps = torch.linspace(start, end, steps)
|
||||||
|
|
||||||
sigs = []
|
sigs = []
|
||||||
timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
|
|
||||||
for x in range(len(timesteps)):
|
for x in range(len(timesteps)):
|
||||||
ts = timesteps[x]
|
ts = timesteps[x]
|
||||||
if ts > 999:
|
sigs.append(s.sigma(ts))
|
||||||
ts = 999
|
|
||||||
sigs.append(model.t_to_sigma(torch.tensor(ts)))
|
|
||||||
sigs += [0.0]
|
sigs += [0.0]
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
@ -389,19 +345,19 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
|||||||
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
|
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
|
||||||
for i in range(len(conditions)):
|
for i in range(len(conditions)):
|
||||||
c = conditions[i]
|
c = conditions[i]
|
||||||
if 'area' in c[1]:
|
if 'area' in c:
|
||||||
area = c[1]['area']
|
area = c['area']
|
||||||
if area[0] == "percentage":
|
if area[0] == "percentage":
|
||||||
modified = c[1].copy()
|
modified = c.copy()
|
||||||
area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
|
area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
|
||||||
modified['area'] = area
|
modified['area'] = area
|
||||||
c = [c[0], modified]
|
c = modified
|
||||||
conditions[i] = c
|
conditions[i] = c
|
||||||
|
|
||||||
if 'mask' in c[1]:
|
if 'mask' in c:
|
||||||
mask = c[1]['mask']
|
mask = c['mask']
|
||||||
mask = mask.to(device=device)
|
mask = mask.to(device=device)
|
||||||
modified = c[1].copy()
|
modified = c.copy()
|
||||||
if len(mask.shape) == 2:
|
if len(mask.shape) == 2:
|
||||||
mask = mask.unsqueeze(0)
|
mask = mask.unsqueeze(0)
|
||||||
if mask.shape[1] != h or mask.shape[2] != w:
|
if mask.shape[1] != h or mask.shape[2] != w:
|
||||||
@ -422,66 +378,70 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
|||||||
modified['area'] = area
|
modified['area'] = area
|
||||||
|
|
||||||
modified['mask'] = mask
|
modified['mask'] = mask
|
||||||
conditions[i] = [c[0], modified]
|
conditions[i] = modified
|
||||||
|
|
||||||
def create_cond_with_same_area_if_none(conds, c):
|
def create_cond_with_same_area_if_none(conds, c):
|
||||||
if 'area' not in c[1]:
|
if 'area' not in c:
|
||||||
return
|
return
|
||||||
|
|
||||||
c_area = c[1]['area']
|
c_area = c['area']
|
||||||
smallest = None
|
smallest = None
|
||||||
for x in conds:
|
for x in conds:
|
||||||
if 'area' in x[1]:
|
if 'area' in x:
|
||||||
a = x[1]['area']
|
a = x['area']
|
||||||
if c_area[2] >= a[2] and c_area[3] >= a[3]:
|
if c_area[2] >= a[2] and c_area[3] >= a[3]:
|
||||||
if a[0] + a[2] >= c_area[0] + c_area[2]:
|
if a[0] + a[2] >= c_area[0] + c_area[2]:
|
||||||
if a[1] + a[3] >= c_area[1] + c_area[3]:
|
if a[1] + a[3] >= c_area[1] + c_area[3]:
|
||||||
if smallest is None:
|
if smallest is None:
|
||||||
smallest = x
|
smallest = x
|
||||||
elif 'area' not in smallest[1]:
|
elif 'area' not in smallest:
|
||||||
smallest = x
|
smallest = x
|
||||||
else:
|
else:
|
||||||
if smallest[1]['area'][0] * smallest[1]['area'][1] > a[0] * a[1]:
|
if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
|
||||||
smallest = x
|
smallest = x
|
||||||
else:
|
else:
|
||||||
if smallest is None:
|
if smallest is None:
|
||||||
smallest = x
|
smallest = x
|
||||||
if smallest is None:
|
if smallest is None:
|
||||||
return
|
return
|
||||||
if 'area' in smallest[1]:
|
if 'area' in smallest:
|
||||||
if smallest[1]['area'] == c_area:
|
if smallest['area'] == c_area:
|
||||||
return
|
return
|
||||||
n = c[1].copy()
|
|
||||||
conds += [[smallest[0], n]]
|
out = c.copy()
|
||||||
|
out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied?
|
||||||
|
conds += [out]
|
||||||
|
|
||||||
def calculate_start_end_timesteps(model, conds):
|
def calculate_start_end_timesteps(model, conds):
|
||||||
|
s = model.model_sampling
|
||||||
for t in range(len(conds)):
|
for t in range(len(conds)):
|
||||||
x = conds[t]
|
x = conds[t]
|
||||||
|
|
||||||
timestep_start = None
|
timestep_start = None
|
||||||
timestep_end = None
|
timestep_end = None
|
||||||
if 'start_percent' in x[1]:
|
if 'start_percent' in x:
|
||||||
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
|
timestep_start = s.percent_to_sigma(x['start_percent'])
|
||||||
if 'end_percent' in x[1]:
|
if 'end_percent' in x:
|
||||||
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
|
timestep_end = s.percent_to_sigma(x['end_percent'])
|
||||||
|
|
||||||
if (timestep_start is not None) or (timestep_end is not None):
|
if (timestep_start is not None) or (timestep_end is not None):
|
||||||
n = x[1].copy()
|
n = x.copy()
|
||||||
if (timestep_start is not None):
|
if (timestep_start is not None):
|
||||||
n['timestep_start'] = timestep_start
|
n['timestep_start'] = timestep_start
|
||||||
if (timestep_end is not None):
|
if (timestep_end is not None):
|
||||||
n['timestep_end'] = timestep_end
|
n['timestep_end'] = timestep_end
|
||||||
conds[t] = [x[0], n]
|
conds[t] = n
|
||||||
|
|
||||||
def pre_run_control(model, conds):
|
def pre_run_control(model, conds):
|
||||||
|
s = model.model_sampling
|
||||||
for t in range(len(conds)):
|
for t in range(len(conds)):
|
||||||
x = conds[t]
|
x = conds[t]
|
||||||
|
|
||||||
timestep_start = None
|
timestep_start = None
|
||||||
timestep_end = None
|
timestep_end = None
|
||||||
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
|
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||||
if 'control' in x[1]:
|
if 'control' in x:
|
||||||
x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
|
x['control'].pre_run(model, percent_to_timestep_function)
|
||||||
|
|
||||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||||
cond_cnets = []
|
cond_cnets = []
|
||||||
@ -490,16 +450,16 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
|||||||
uncond_other = []
|
uncond_other = []
|
||||||
for t in range(len(conds)):
|
for t in range(len(conds)):
|
||||||
x = conds[t]
|
x = conds[t]
|
||||||
if 'area' not in x[1]:
|
if 'area' not in x:
|
||||||
if name in x[1] and x[1][name] is not None:
|
if name in x and x[name] is not None:
|
||||||
cond_cnets.append(x[1][name])
|
cond_cnets.append(x[name])
|
||||||
else:
|
else:
|
||||||
cond_other.append((x, t))
|
cond_other.append((x, t))
|
||||||
for t in range(len(uncond)):
|
for t in range(len(uncond)):
|
||||||
x = uncond[t]
|
x = uncond[t]
|
||||||
if 'area' not in x[1]:
|
if 'area' not in x:
|
||||||
if name in x[1] and x[1][name] is not None:
|
if name in x and x[name] is not None:
|
||||||
uncond_cnets.append(x[1][name])
|
uncond_cnets.append(x[name])
|
||||||
else:
|
else:
|
||||||
uncond_other.append((x, t))
|
uncond_other.append((x, t))
|
||||||
|
|
||||||
@ -509,47 +469,35 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
|||||||
for x in range(len(cond_cnets)):
|
for x in range(len(cond_cnets)):
|
||||||
temp = uncond_other[x % len(uncond_other)]
|
temp = uncond_other[x % len(uncond_other)]
|
||||||
o = temp[0]
|
o = temp[0]
|
||||||
if name in o[1] and o[1][name] is not None:
|
if name in o and o[name] is not None:
|
||||||
n = o[1].copy()
|
n = o.copy()
|
||||||
n[name] = uncond_fill_func(cond_cnets, x)
|
n[name] = uncond_fill_func(cond_cnets, x)
|
||||||
uncond += [[o[0], n]]
|
uncond += [n]
|
||||||
else:
|
else:
|
||||||
n = o[1].copy()
|
n = o.copy()
|
||||||
n[name] = uncond_fill_func(cond_cnets, x)
|
n[name] = uncond_fill_func(cond_cnets, x)
|
||||||
uncond[temp[1]] = [o[0], n]
|
uncond[temp[1]] = n
|
||||||
|
|
||||||
def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
|
def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs):
|
||||||
for t in range(len(conds)):
|
for t in range(len(conds)):
|
||||||
x = conds[t]
|
x = conds[t]
|
||||||
adm_out = None
|
params = x.copy()
|
||||||
if 'adm' in x[1]:
|
|
||||||
adm_out = x[1]["adm"]
|
|
||||||
else:
|
|
||||||
params = x[1].copy()
|
|
||||||
params["width"] = params.get("width", width * 8)
|
|
||||||
params["height"] = params.get("height", height * 8)
|
|
||||||
params["prompt_type"] = params.get("prompt_type", prompt_type)
|
|
||||||
adm_out = model.encode_adm(device=device, **params)
|
|
||||||
|
|
||||||
if adm_out is not None:
|
|
||||||
x[1] = x[1].copy()
|
|
||||||
x[1]["adm_encoded"] = comfy.utils.repeat_to_batch_size(adm_out, batch_size).to(device)
|
|
||||||
|
|
||||||
return conds
|
|
||||||
|
|
||||||
def encode_cond(model_function, key, conds, device, **kwargs):
|
|
||||||
for t in range(len(conds)):
|
|
||||||
x = conds[t]
|
|
||||||
params = x[1].copy()
|
|
||||||
params["device"] = device
|
params["device"] = device
|
||||||
|
params["noise"] = noise
|
||||||
|
params["width"] = params.get("width", noise.shape[3] * 8)
|
||||||
|
params["height"] = params.get("height", noise.shape[2] * 8)
|
||||||
|
params["prompt_type"] = params.get("prompt_type", prompt_type)
|
||||||
for k in kwargs:
|
for k in kwargs:
|
||||||
if k not in params:
|
if k not in params:
|
||||||
params[k] = kwargs[k]
|
params[k] = kwargs[k]
|
||||||
|
|
||||||
out = model_function(**params)
|
out = model_function(**params)
|
||||||
if out is not None:
|
x = x.copy()
|
||||||
x[1] = x[1].copy()
|
model_conds = x['model_conds'].copy()
|
||||||
x[1][key] = out
|
for k in out:
|
||||||
|
model_conds[k] = out[k]
|
||||||
|
x['model_conds'] = model_conds
|
||||||
|
conds[t] = x
|
||||||
return conds
|
return conds
|
||||||
|
|
||||||
class Sampler:
|
class Sampler:
|
||||||
@ -557,42 +505,9 @@ class Sampler:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def max_denoise(self, model_wrap, sigmas):
|
def max_denoise(self, model_wrap, sigmas):
|
||||||
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05)
|
max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
|
||||||
|
sigma = float(sigmas[0])
|
||||||
class DDIM(Sampler):
|
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
|
||||||
timesteps = []
|
|
||||||
for s in range(sigmas.shape[0]):
|
|
||||||
timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s]))
|
|
||||||
noise_mask = None
|
|
||||||
if denoise_mask is not None:
|
|
||||||
noise_mask = 1.0 - denoise_mask
|
|
||||||
|
|
||||||
ddim_callback = None
|
|
||||||
if callback is not None:
|
|
||||||
total_steps = len(timesteps) - 1
|
|
||||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
|
||||||
|
|
||||||
max_denoise = self.max_denoise(model_wrap, sigmas)
|
|
||||||
|
|
||||||
ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
|
|
||||||
ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
|
||||||
z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise)
|
|
||||||
samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
|
|
||||||
batch_size=noise.shape[0],
|
|
||||||
shape=noise.shape[1:],
|
|
||||||
verbose=False,
|
|
||||||
eta=0.0,
|
|
||||||
x_T=z_enc,
|
|
||||||
x0=latent_image,
|
|
||||||
img_callback=ddim_callback,
|
|
||||||
denoise_function=model_wrap.predict_eps_discrete_timestep,
|
|
||||||
extra_args=extra_args,
|
|
||||||
mask=noise_mask,
|
|
||||||
to_zero=sigmas[-1]==0,
|
|
||||||
end_step=sigmas.shape[0] - 1,
|
|
||||||
disable_pbar=disable_pbar)
|
|
||||||
return samples
|
|
||||||
|
|
||||||
class UNIPC(Sampler):
|
class UNIPC(Sampler):
|
||||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||||
@ -606,13 +521,17 @@ KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral"
|
|||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
|
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
|
||||||
|
|
||||||
def ksampler(sampler_name, extra_options={}):
|
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||||
extra_args["denoise_mask"] = denoise_mask
|
extra_args["denoise_mask"] = denoise_mask
|
||||||
model_k = KSamplerX0Inpaint(model_wrap)
|
model_k = KSamplerX0Inpaint(model_wrap)
|
||||||
model_k.latent_image = latent_image
|
model_k.latent_image = latent_image
|
||||||
model_k.noise = noise
|
if inpaint_options.get("random", False): #TODO: Should this be the default?
|
||||||
|
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
|
||||||
|
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
|
||||||
|
else:
|
||||||
|
model_k.noise = noise
|
||||||
|
|
||||||
if self.max_denoise(model_wrap, sigmas):
|
if self.max_denoise(model_wrap, sigmas):
|
||||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||||
@ -641,11 +560,7 @@ def ksampler(sampler_name, extra_options={}):
|
|||||||
|
|
||||||
def wrap_model(model):
|
def wrap_model(model):
|
||||||
model_denoise = CFGNoisePredictor(model)
|
model_denoise = CFGNoisePredictor(model)
|
||||||
if model.model_type == model_base.ModelType.V_PREDICTION:
|
return model_denoise
|
||||||
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
|
|
||||||
else:
|
|
||||||
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
|
|
||||||
return model_wrap
|
|
||||||
|
|
||||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
positive = positive[:]
|
positive = positive[:]
|
||||||
@ -656,8 +571,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||||||
|
|
||||||
model_wrap = wrap_model(model)
|
model_wrap = wrap_model(model)
|
||||||
|
|
||||||
calculate_start_end_timesteps(model_wrap, negative)
|
calculate_start_end_timesteps(model, negative)
|
||||||
calculate_start_end_timesteps(model_wrap, positive)
|
calculate_start_end_timesteps(model, positive)
|
||||||
|
|
||||||
#make sure each cond area has an opposite one with the same area
|
#make sure each cond area has an opposite one with the same area
|
||||||
for c in positive:
|
for c in positive:
|
||||||
@ -665,21 +580,17 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||||||
for c in negative:
|
for c in negative:
|
||||||
create_cond_with_same_area_if_none(positive, c)
|
create_cond_with_same_area_if_none(positive, c)
|
||||||
|
|
||||||
pre_run_control(model_wrap, negative + positive)
|
pre_run_control(model, negative + positive)
|
||||||
|
|
||||||
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
|
||||||
if latent_image is not None:
|
if latent_image is not None:
|
||||||
latent_image = model.process_latent_in(latent_image)
|
latent_image = model.process_latent_in(latent_image)
|
||||||
|
|
||||||
if model.is_adm():
|
if hasattr(model, 'extra_conds'):
|
||||||
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
|
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
|
||||||
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
|
||||||
|
|
||||||
if hasattr(model, 'cond_concat'):
|
|
||||||
positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
|
||||||
negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
|
||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||||
|
|
||||||
@ -690,19 +601,18 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "
|
|||||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||||
|
|
||||||
def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
||||||
model_wrap = wrap_model(model)
|
|
||||||
if scheduler_name == "karras":
|
if scheduler_name == "karras":
|
||||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
|
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||||
elif scheduler_name == "exponential":
|
elif scheduler_name == "exponential":
|
||||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
|
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||||
elif scheduler_name == "normal":
|
elif scheduler_name == "normal":
|
||||||
sigmas = model_wrap.get_sigmas(steps)
|
sigmas = normal_scheduler(model, steps)
|
||||||
elif scheduler_name == "simple":
|
elif scheduler_name == "simple":
|
||||||
sigmas = simple_scheduler(model_wrap, steps)
|
sigmas = simple_scheduler(model, steps)
|
||||||
elif scheduler_name == "ddim_uniform":
|
elif scheduler_name == "ddim_uniform":
|
||||||
sigmas = ddim_scheduler(model_wrap, steps)
|
sigmas = ddim_scheduler(model, steps)
|
||||||
elif scheduler_name == "sgm_uniform":
|
elif scheduler_name == "sgm_uniform":
|
||||||
sigmas = sgm_scheduler(model_wrap, steps)
|
sigmas = normal_scheduler(model, steps, sgm=True)
|
||||||
else:
|
else:
|
||||||
print("error invalid scheduler", self.scheduler)
|
print("error invalid scheduler", self.scheduler)
|
||||||
return sigmas
|
return sigmas
|
||||||
@ -713,7 +623,7 @@ def sampler_class(name):
|
|||||||
elif name == "uni_pc_bh2":
|
elif name == "uni_pc_bh2":
|
||||||
sampler = UNIPCBH2
|
sampler = UNIPCBH2
|
||||||
elif name == "ddim":
|
elif name == "ddim":
|
||||||
sampler = DDIM
|
sampler = ksampler("euler", inpaint_options={"random": True})
|
||||||
else:
|
else:
|
||||||
sampler = ksampler(name)
|
sampler = ksampler(name)
|
||||||
return sampler
|
return sampler
|
||||||
|
|||||||
33
comfy/sd.py
33
comfy/sd.py
@ -55,13 +55,26 @@ def load_clip_weights(model, sd):
|
|||||||
|
|
||||||
|
|
||||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||||
key_map = comfy.lora.model_lora_keys_unet(model.model)
|
key_map = {}
|
||||||
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
if model is not None:
|
||||||
|
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||||
|
if clip is not None:
|
||||||
|
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||||
|
|
||||||
loaded = comfy.lora.load_lora(lora, key_map)
|
loaded = comfy.lora.load_lora(lora, key_map)
|
||||||
new_modelpatcher = model.clone()
|
if model is not None:
|
||||||
k = new_modelpatcher.add_patches(loaded, strength_model)
|
new_modelpatcher = model.clone()
|
||||||
new_clip = clip.clone()
|
k = new_modelpatcher.add_patches(loaded, strength_model)
|
||||||
k1 = new_clip.add_patches(loaded, strength_clip)
|
else:
|
||||||
|
k = ()
|
||||||
|
new_modelpatcher = None
|
||||||
|
|
||||||
|
if clip is not None:
|
||||||
|
new_clip = clip.clone()
|
||||||
|
k1 = new_clip.add_patches(loaded, strength_clip)
|
||||||
|
else:
|
||||||
|
k1 = ()
|
||||||
|
new_clip = None
|
||||||
k = set(k)
|
k = set(k)
|
||||||
k1 = set(k1)
|
k1 = set(k1)
|
||||||
for x in loaded:
|
for x in loaded:
|
||||||
@ -360,7 +373,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
|
|
||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
|
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
|
||||||
model_config.unet_config = unet_config
|
model_config.unet_config = model_detection.convert_config(unet_config)
|
||||||
|
|
||||||
if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
||||||
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
|
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
|
||||||
@ -388,11 +401,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
|
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
|
||||||
clip_target.clip = sd2_clip.SD2ClipModel
|
clip_target.clip = sd2_clip.SD2ClipModel
|
||||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||||
|
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||||
|
w.cond_stage_model = clip.cond_stage_model.clip_h
|
||||||
elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
|
elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
|
||||||
clip_target.clip = sd1_clip.SD1ClipModel
|
clip_target.clip = sd1_clip.SD1ClipModel
|
||||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||||
w.cond_stage_model = clip.cond_stage_model
|
w.cond_stage_model = clip.cond_stage_model.clip_l
|
||||||
load_clip_weights(w, state_dict)
|
load_clip_weights(w, state_dict)
|
||||||
|
|
||||||
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
||||||
|
|||||||
@ -35,7 +35,7 @@ class ClipTokenWeightEncoder:
|
|||||||
return z_empty.cpu(), first_pooled.cpu()
|
return z_empty.cpu(), first_pooled.cpu()
|
||||||
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
|
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
|
||||||
|
|
||||||
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||||
LAYERS = [
|
LAYERS = [
|
||||||
"last",
|
"last",
|
||||||
@ -278,7 +278,13 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
|||||||
|
|
||||||
valid_file = None
|
valid_file = None
|
||||||
for embed_dir in embedding_directory:
|
for embed_dir in embedding_directory:
|
||||||
embed_path = os.path.join(embed_dir, embedding_name)
|
embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
|
||||||
|
embed_dir = os.path.abspath(embed_dir)
|
||||||
|
try:
|
||||||
|
if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
|
||||||
|
continue
|
||||||
|
except:
|
||||||
|
continue
|
||||||
if not os.path.isfile(embed_path):
|
if not os.path.isfile(embed_path):
|
||||||
extensions = ['.safetensors', '.pt', '.bin']
|
extensions = ['.safetensors', '.pt', '.bin']
|
||||||
for x in extensions:
|
for x in extensions:
|
||||||
@ -336,7 +342,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
|||||||
embed_out = next(iter(values))
|
embed_out = next(iter(values))
|
||||||
return embed_out
|
return embed_out
|
||||||
|
|
||||||
class SD1Tokenizer:
|
class SDTokenizer:
|
||||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
|
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
|
||||||
if tokenizer_path is None:
|
if tokenizer_path is None:
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||||
@ -448,3 +454,40 @@ class SD1Tokenizer:
|
|||||||
|
|
||||||
def untokenize(self, token_weight_pair):
|
def untokenize(self, token_weight_pair):
|
||||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||||
|
|
||||||
|
|
||||||
|
class SD1Tokenizer:
|
||||||
|
def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
|
||||||
|
self.clip_name = clip_name
|
||||||
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
|
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
|
out = {}
|
||||||
|
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def untokenize(self, token_weight_pair):
|
||||||
|
return getattr(self, self.clip).untokenize(token_weight_pair)
|
||||||
|
|
||||||
|
|
||||||
|
class SD1ClipModel(torch.nn.Module):
|
||||||
|
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.clip_name = clip_name
|
||||||
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
|
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
||||||
|
|
||||||
|
def clip_layer(self, layer_idx):
|
||||||
|
getattr(self, self.clip).clip_layer(layer_idx)
|
||||||
|
|
||||||
|
def reset_clip_layer(self):
|
||||||
|
getattr(self, self.clip).reset_clip_layer()
|
||||||
|
|
||||||
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
|
token_weight_pairs = token_weight_pairs[self.clip_name]
|
||||||
|
out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
|
||||||
|
return out, pooled
|
||||||
|
|
||||||
|
def load_sd(self, sd):
|
||||||
|
return getattr(self, self.clip).load_sd(sd)
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from comfy import sd1_clip
|
|||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
|
||||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
||||||
if layer == "penultimate":
|
if layer == "penultimate":
|
||||||
layer="hidden"
|
layer="hidden"
|
||||||
@ -12,6 +12,14 @@ class SD2ClipModel(sd1_clip.SD1ClipModel):
|
|||||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
|
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
|
||||||
self.empty_tokens = [[49406] + [49407] + [0] * 75]
|
self.empty_tokens = [[49406] + [49407] + [0] * 75]
|
||||||
|
|
||||||
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
|
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
|
||||||
|
|
||||||
|
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||||
|
|
||||||
|
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||||
|
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from comfy import sd1_clip
|
|||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
|
||||||
class SDXLClipG(sd1_clip.SD1ClipModel):
|
class SDXLClipG(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
||||||
if layer == "penultimate":
|
if layer == "penultimate":
|
||||||
layer="hidden"
|
layer="hidden"
|
||||||
@ -16,14 +16,14 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
|
|||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return super().load_sd(sd)
|
return super().load_sd(sd)
|
||||||
|
|
||||||
class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
|
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
||||||
|
|
||||||
|
|
||||||
class SDXLTokenizer(sd1_clip.SD1Tokenizer):
|
class SDXLTokenizer:
|
||||||
def __init__(self, embedding_directory=None):
|
def __init__(self, embedding_directory=None):
|
||||||
self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory)
|
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
@ -38,7 +38,7 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
class SDXLClipModel(torch.nn.Module):
|
class SDXLClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
|
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
|
||||||
self.clip_l.layer_norm_hidden_state = False
|
self.clip_l.layer_norm_hidden_state = False
|
||||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||||
|
|
||||||
@ -63,21 +63,6 @@ class SDXLClipModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
return self.clip_l.load_sd(sd)
|
return self.clip_l.load_sd(sd)
|
||||||
|
|
||||||
class SDXLRefinerClipModel(torch.nn.Module):
|
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None):
|
||||||
super().__init__()
|
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
|
||||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def clip_layer(self, layer_idx):
|
|
||||||
self.clip_g.clip_layer(layer_idx)
|
|
||||||
|
|
||||||
def reset_clip_layer(self):
|
|
||||||
self.clip_g.reset_clip_layer()
|
|
||||||
|
|
||||||
def encode_token_weights(self, token_weight_pairs):
|
|
||||||
token_weight_pairs_g = token_weight_pairs["g"]
|
|
||||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
|
||||||
return g_out, g_pooled
|
|
||||||
|
|
||||||
def load_sd(self, sd):
|
|
||||||
return self.clip_g.load_sd(sd)
|
|
||||||
|
|||||||
@ -38,8 +38,15 @@ class SD15(supported_models_base.BASE):
|
|||||||
if ids.dtype == torch.float32:
|
if ids.dtype == torch.float32:
|
||||||
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
||||||
|
|
||||||
|
replace_prefix = {}
|
||||||
|
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
|
||||||
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
def process_clip_state_dict_for_saving(self, state_dict):
|
||||||
|
replace_prefix = {"clip_l.": "cond_stage_model."}
|
||||||
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self):
|
||||||
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
|
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
|
||||||
|
|
||||||
@ -62,12 +69,12 @@ class SD20(supported_models_base.BASE):
|
|||||||
return model_base.ModelType.EPS
|
return model_base.ModelType.EPS
|
||||||
|
|
||||||
def process_clip_state_dict(self, state_dict):
|
def process_clip_state_dict(self, state_dict):
|
||||||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
|
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
def process_clip_state_dict_for_saving(self, state_dict):
|
def process_clip_state_dict_for_saving(self, state_dict):
|
||||||
replace_prefix = {}
|
replace_prefix = {}
|
||||||
replace_prefix[""] = "cond_stage_model.model."
|
replace_prefix["clip_h"] = "cond_stage_model.model"
|
||||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
||||||
return state_dict
|
return state_dict
|
||||||
@ -104,7 +111,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
|||||||
"use_linear_in_transformer": True,
|
"use_linear_in_transformer": True,
|
||||||
"context_dim": 1280,
|
"context_dim": 1280,
|
||||||
"adm_in_channels": 2560,
|
"adm_in_channels": 2560,
|
||||||
"transformer_depth": [0, 4, 4, 0],
|
"transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
|
||||||
}
|
}
|
||||||
|
|
||||||
latent_format = latent_formats.SDXL
|
latent_format = latent_formats.SDXL
|
||||||
@ -139,7 +146,7 @@ class SDXL(supported_models_base.BASE):
|
|||||||
unet_config = {
|
unet_config = {
|
||||||
"model_channels": 320,
|
"model_channels": 320,
|
||||||
"use_linear_in_transformer": True,
|
"use_linear_in_transformer": True,
|
||||||
"transformer_depth": [0, 2, 10],
|
"transformer_depth": [0, 0, 2, 2, 10, 10],
|
||||||
"context_dim": 2048,
|
"context_dim": 2048,
|
||||||
"adm_in_channels": 2816
|
"adm_in_channels": 2816
|
||||||
}
|
}
|
||||||
@ -165,6 +172,7 @@ class SDXL(supported_models_base.BASE):
|
|||||||
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
|
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
|
||||||
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
|
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
|
||||||
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
|
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
|
||||||
|
keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
|
||||||
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
|
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
|
||||||
|
|
||||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
@ -189,5 +197,14 @@ class SDXL(supported_models_base.BASE):
|
|||||||
def clip_target(self):
|
def clip_target(self):
|
||||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
|
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
|
||||||
|
|
||||||
|
class SSD1B(SDXL):
|
||||||
|
unet_config = {
|
||||||
|
"model_channels": 320,
|
||||||
|
"use_linear_in_transformer": True,
|
||||||
|
"transformer_depth": [0, 0, 2, 2, 4, 4],
|
||||||
|
"context_dim": 2048,
|
||||||
|
"adm_in_channels": 2816
|
||||||
|
}
|
||||||
|
|
||||||
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL]
|
|
||||||
|
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
|
||||||
|
|||||||
@ -170,25 +170,12 @@ UNET_MAP_BASIC = {
|
|||||||
|
|
||||||
def unet_to_diffusers(unet_config):
|
def unet_to_diffusers(unet_config):
|
||||||
num_res_blocks = unet_config["num_res_blocks"]
|
num_res_blocks = unet_config["num_res_blocks"]
|
||||||
attention_resolutions = unet_config["attention_resolutions"]
|
|
||||||
channel_mult = unet_config["channel_mult"]
|
channel_mult = unet_config["channel_mult"]
|
||||||
transformer_depth = unet_config["transformer_depth"]
|
transformer_depth = unet_config["transformer_depth"][:]
|
||||||
|
transformer_depth_output = unet_config["transformer_depth_output"][:]
|
||||||
num_blocks = len(channel_mult)
|
num_blocks = len(channel_mult)
|
||||||
if isinstance(num_res_blocks, int):
|
|
||||||
num_res_blocks = [num_res_blocks] * num_blocks
|
|
||||||
if isinstance(transformer_depth, int):
|
|
||||||
transformer_depth = [transformer_depth] * num_blocks
|
|
||||||
|
|
||||||
transformers_per_layer = []
|
transformers_mid = unet_config.get("transformer_depth_middle", None)
|
||||||
res = 1
|
|
||||||
for i in range(num_blocks):
|
|
||||||
transformers = 0
|
|
||||||
if res in attention_resolutions:
|
|
||||||
transformers = transformer_depth[i]
|
|
||||||
transformers_per_layer.append(transformers)
|
|
||||||
res *= 2
|
|
||||||
|
|
||||||
transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])
|
|
||||||
|
|
||||||
diffusers_unet_map = {}
|
diffusers_unet_map = {}
|
||||||
for x in range(num_blocks):
|
for x in range(num_blocks):
|
||||||
@ -196,10 +183,11 @@ def unet_to_diffusers(unet_config):
|
|||||||
for i in range(num_res_blocks[x]):
|
for i in range(num_res_blocks[x]):
|
||||||
for b in UNET_MAP_RESNET:
|
for b in UNET_MAP_RESNET:
|
||||||
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
|
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
|
||||||
if transformers_per_layer[x] > 0:
|
num_transformers = transformer_depth.pop(0)
|
||||||
|
if num_transformers > 0:
|
||||||
for b in UNET_MAP_ATTENTIONS:
|
for b in UNET_MAP_ATTENTIONS:
|
||||||
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
|
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
|
||||||
for t in range(transformers_per_layer[x]):
|
for t in range(num_transformers):
|
||||||
for b in TRANSFORMER_BLOCKS:
|
for b in TRANSFORMER_BLOCKS:
|
||||||
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
||||||
n += 1
|
n += 1
|
||||||
@ -218,7 +206,6 @@ def unet_to_diffusers(unet_config):
|
|||||||
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
|
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
|
||||||
|
|
||||||
num_res_blocks = list(reversed(num_res_blocks))
|
num_res_blocks = list(reversed(num_res_blocks))
|
||||||
transformers_per_layer = list(reversed(transformers_per_layer))
|
|
||||||
for x in range(num_blocks):
|
for x in range(num_blocks):
|
||||||
n = (num_res_blocks[x] + 1) * x
|
n = (num_res_blocks[x] + 1) * x
|
||||||
l = num_res_blocks[x] + 1
|
l = num_res_blocks[x] + 1
|
||||||
@ -227,11 +214,12 @@ def unet_to_diffusers(unet_config):
|
|||||||
for b in UNET_MAP_RESNET:
|
for b in UNET_MAP_RESNET:
|
||||||
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
|
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
|
||||||
c += 1
|
c += 1
|
||||||
if transformers_per_layer[x] > 0:
|
num_transformers = transformer_depth_output.pop()
|
||||||
|
if num_transformers > 0:
|
||||||
c += 1
|
c += 1
|
||||||
for b in UNET_MAP_ATTENTIONS:
|
for b in UNET_MAP_ATTENTIONS:
|
||||||
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
|
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
|
||||||
for t in range(transformers_per_layer[x]):
|
for t in range(num_transformers):
|
||||||
for b in TRANSFORMER_BLOCKS:
|
for b in TRANSFORMER_BLOCKS:
|
||||||
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
||||||
if i == l - 1:
|
if i == l - 1:
|
||||||
|
|||||||
@ -126,7 +126,7 @@ class Quantize:
|
|||||||
"max": 256,
|
"max": 256,
|
||||||
"step": 1
|
"step": 1
|
||||||
}),
|
}),
|
||||||
"dither": (["none", "floyd-steinberg"],),
|
"dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,19 +135,47 @@ class Quantize:
|
|||||||
|
|
||||||
CATEGORY = "image/postprocessing"
|
CATEGORY = "image/postprocessing"
|
||||||
|
|
||||||
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
|
def bayer(im, pal_im, order):
|
||||||
|
def normalized_bayer_matrix(n):
|
||||||
|
if n == 0:
|
||||||
|
return np.zeros((1,1), "float32")
|
||||||
|
else:
|
||||||
|
q = 4 ** n
|
||||||
|
m = q * normalized_bayer_matrix(n - 1)
|
||||||
|
return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q
|
||||||
|
|
||||||
|
num_colors = len(pal_im.getpalette()) // 3
|
||||||
|
spread = 2 * 256 / num_colors
|
||||||
|
bayer_n = int(math.log2(order))
|
||||||
|
bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)
|
||||||
|
|
||||||
|
result = torch.from_numpy(np.array(im).astype(np.float32))
|
||||||
|
tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
|
||||||
|
th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
|
||||||
|
tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
|
||||||
|
result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255)
|
||||||
|
result = result.to(dtype=torch.uint8)
|
||||||
|
|
||||||
|
im = Image.fromarray(result.cpu().numpy())
|
||||||
|
im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
|
||||||
|
return im
|
||||||
|
|
||||||
|
def quantize(self, image: torch.Tensor, colors: int, dither: str):
|
||||||
batch_size, height, width, _ = image.shape
|
batch_size, height, width, _ = image.shape
|
||||||
result = torch.zeros_like(image)
|
result = torch.zeros_like(image)
|
||||||
|
|
||||||
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
|
|
||||||
|
|
||||||
for b in range(batch_size):
|
for b in range(batch_size):
|
||||||
tensor_image = image[b]
|
im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')
|
||||||
img = (tensor_image * 255).to(torch.uint8).numpy()
|
|
||||||
pil_image = Image.fromarray(img, mode='RGB')
|
|
||||||
|
|
||||||
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
|
pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
|
||||||
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
|
|
||||||
|
if dither == "none":
|
||||||
|
quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
|
||||||
|
elif dither == "floyd-steinberg":
|
||||||
|
quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
|
||||||
|
elif dither.startswith("bayer"):
|
||||||
|
order = int(dither.split('-')[-1])
|
||||||
|
quantized_image = Quantize.bayer(im, pal_im, order)
|
||||||
|
|
||||||
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
|
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
|
||||||
result[b] = quantized_array
|
result[b] = quantized_array
|
||||||
|
|||||||
@ -4,7 +4,7 @@ class LatentRebatch:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "latents": ("LATENT",),
|
return {"required": { "latents": ("LATENT",),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
INPUT_IS_LIST = True
|
INPUT_IS_LIST = True
|
||||||
|
|||||||
@ -22,7 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
|
|||||||
self.taesd = taesd
|
self.taesd = taesd
|
||||||
|
|
||||||
def decode_latent_to_preview(self, x0):
|
def decode_latent_to_preview(self, x0):
|
||||||
x_sample = self.taesd.decoder(x0)[0].detach()
|
x_sample = self.taesd.decoder(x0[:1])[0].detach()
|
||||||
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
|
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
|
||||||
x_sample = x_sample.sub(0.5).mul(2)
|
x_sample = x_sample.sub(0.5).mul(2)
|
||||||
|
|
||||||
|
|||||||
@ -82,7 +82,8 @@ class PromptServer():
|
|||||||
if args.enable_cors_header:
|
if args.enable_cors_header:
|
||||||
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
middlewares.append(create_cors_middleware(args.enable_cors_header))
|
||||||
|
|
||||||
self.app = web.Application(client_max_size=104857600, middlewares=middlewares)
|
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||||
|
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||||
self.sockets = dict()
|
self.sockets = dict()
|
||||||
self.web_root = os.path.join(os.path.dirname(
|
self.web_root = os.path.join(os.path.dirname(
|
||||||
os.path.realpath(__file__)), "web")
|
os.path.realpath(__file__)), "web")
|
||||||
|
|||||||
@ -25,7 +25,7 @@ const ext = {
|
|||||||
requestAnimationFrame(() => {
|
requestAnimationFrame(() => {
|
||||||
const currentNode = LGraphCanvas.active_canvas.current_node;
|
const currentNode = LGraphCanvas.active_canvas.current_node;
|
||||||
const clickedComboValue = currentNode.widgets
|
const clickedComboValue = currentNode.widgets
|
||||||
.filter(w => w.type === "combo" && w.options.values.length === values.length)
|
?.filter(w => w.type === "combo" && w.options.values.length === values.length)
|
||||||
.find(w => w.options.values.every((v, i) => v === values[i]))
|
.find(w => w.options.values.every((v, i) => v === values[i]))
|
||||||
?.value;
|
?.value;
|
||||||
|
|
||||||
|
|||||||
@ -15,6 +15,9 @@ import { GROUP_DATA, IS_GROUP_NODE, registerGroupNodes } from "./groupNode.js";
|
|||||||
// To delete/rename:
|
// To delete/rename:
|
||||||
// Right click the canvas
|
// Right click the canvas
|
||||||
// Node templates -> Manage
|
// Node templates -> Manage
|
||||||
|
//
|
||||||
|
// To rearrange:
|
||||||
|
// Open the manage dialog and Drag and drop elements using the "Name:" label as handle
|
||||||
|
|
||||||
const id = "Comfy.NodeTemplates";
|
const id = "Comfy.NodeTemplates";
|
||||||
|
|
||||||
@ -23,6 +26,10 @@ class ManageTemplates extends ComfyDialog {
|
|||||||
super();
|
super();
|
||||||
this.element.classList.add("comfy-manage-templates");
|
this.element.classList.add("comfy-manage-templates");
|
||||||
this.templates = this.load();
|
this.templates = this.load();
|
||||||
|
this.draggedEl = null;
|
||||||
|
this.saveVisualCue = null;
|
||||||
|
this.emptyImg = new Image();
|
||||||
|
this.emptyImg.src = 'data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs=';
|
||||||
|
|
||||||
this.importInput = $el("input", {
|
this.importInput = $el("input", {
|
||||||
type: "file",
|
type: "file",
|
||||||
@ -36,14 +43,11 @@ class ManageTemplates extends ComfyDialog {
|
|||||||
|
|
||||||
createButtons() {
|
createButtons() {
|
||||||
const btns = super.createButtons();
|
const btns = super.createButtons();
|
||||||
btns[0].textContent = "Cancel";
|
btns[0].textContent = "Close";
|
||||||
btns.unshift(
|
btns[0].onclick = (e) => {
|
||||||
$el("button", {
|
clearTimeout(this.saveVisualCue);
|
||||||
type: "button",
|
this.close();
|
||||||
textContent: "Save",
|
};
|
||||||
onclick: () => this.save(),
|
|
||||||
})
|
|
||||||
);
|
|
||||||
btns.unshift(
|
btns.unshift(
|
||||||
$el("button", {
|
$el("button", {
|
||||||
type: "button",
|
type: "button",
|
||||||
@ -72,25 +76,6 @@ class ManageTemplates extends ComfyDialog {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
save() {
|
|
||||||
// Find all visible inputs and save them as our new list
|
|
||||||
const inputs = this.element.querySelectorAll("input");
|
|
||||||
const updated = [];
|
|
||||||
|
|
||||||
for (let i = 0; i < inputs.length; i++) {
|
|
||||||
const input = inputs[i];
|
|
||||||
if (input.parentElement.style.display !== "none") {
|
|
||||||
const t = this.templates[i];
|
|
||||||
t.name = input.value.trim() || input.getAttribute("data-name");
|
|
||||||
updated.push(t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
this.templates = updated;
|
|
||||||
this.store();
|
|
||||||
this.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
store() {
|
store() {
|
||||||
localStorage.setItem(id, JSON.stringify(this.templates));
|
localStorage.setItem(id, JSON.stringify(this.templates));
|
||||||
}
|
}
|
||||||
@ -146,71 +131,155 @@ class ManageTemplates extends ComfyDialog {
|
|||||||
super.show(
|
super.show(
|
||||||
$el(
|
$el(
|
||||||
"div",
|
"div",
|
||||||
{
|
{},
|
||||||
style: {
|
this.templates.flatMap((t,i) => {
|
||||||
display: "grid",
|
|
||||||
gridTemplateColumns: "1fr auto",
|
|
||||||
gap: "5px",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
this.templates.flatMap((t) => {
|
|
||||||
let nameInput;
|
let nameInput;
|
||||||
return [
|
return [
|
||||||
$el(
|
$el(
|
||||||
"label",
|
"div",
|
||||||
{
|
{
|
||||||
textContent: "Name: ",
|
dataset: { id: i },
|
||||||
|
className: "tempateManagerRow",
|
||||||
|
style: {
|
||||||
|
display: "grid",
|
||||||
|
gridTemplateColumns: "1fr auto",
|
||||||
|
border: "1px dashed transparent",
|
||||||
|
gap: "5px",
|
||||||
|
backgroundColor: "var(--comfy-menu-bg)"
|
||||||
|
},
|
||||||
|
ondragstart: (e) => {
|
||||||
|
this.draggedEl = e.currentTarget;
|
||||||
|
e.currentTarget.style.opacity = "0.6";
|
||||||
|
e.currentTarget.style.border = "1px dashed yellow";
|
||||||
|
e.dataTransfer.effectAllowed = 'move';
|
||||||
|
e.dataTransfer.setDragImage(this.emptyImg, 0, 0);
|
||||||
|
},
|
||||||
|
ondragend: (e) => {
|
||||||
|
e.target.style.opacity = "1";
|
||||||
|
e.currentTarget.style.border = "1px dashed transparent";
|
||||||
|
e.currentTarget.removeAttribute("draggable");
|
||||||
|
|
||||||
|
// rearrange the elements in the localStorage
|
||||||
|
this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => {
|
||||||
|
var prev_i = el.dataset.id;
|
||||||
|
|
||||||
|
if ( el == this.draggedEl && prev_i != i ) {
|
||||||
|
[this.templates[i], this.templates[prev_i]] = [this.templates[prev_i], this.templates[i]];
|
||||||
|
}
|
||||||
|
el.dataset.id = i;
|
||||||
|
});
|
||||||
|
this.store();
|
||||||
|
},
|
||||||
|
ondragover: (e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
if ( e.currentTarget == this.draggedEl )
|
||||||
|
return;
|
||||||
|
|
||||||
|
let rect = e.currentTarget.getBoundingClientRect();
|
||||||
|
if (e.clientY > rect.top + rect.height / 2) {
|
||||||
|
e.currentTarget.parentNode.insertBefore(this.draggedEl, e.currentTarget.nextSibling);
|
||||||
|
} else {
|
||||||
|
e.currentTarget.parentNode.insertBefore(this.draggedEl, e.currentTarget);
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
$el("input", {
|
$el(
|
||||||
value: t.name,
|
"label",
|
||||||
dataset: { name: t.name },
|
{
|
||||||
$: (el) => (nameInput = el),
|
textContent: "Name: ",
|
||||||
}),
|
style: {
|
||||||
|
cursor: "grab",
|
||||||
|
},
|
||||||
|
onmousedown: (e) => {
|
||||||
|
// enable dragging only from the label
|
||||||
|
if (e.target.localName == 'label')
|
||||||
|
e.currentTarget.parentNode.draggable = 'true';
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[
|
||||||
|
$el("input", {
|
||||||
|
value: t.name,
|
||||||
|
dataset: { name: t.name },
|
||||||
|
style: {
|
||||||
|
transitionProperty: 'background-color',
|
||||||
|
transitionDuration: '0s',
|
||||||
|
},
|
||||||
|
onchange: (e) => {
|
||||||
|
clearTimeout(this.saveVisualCue);
|
||||||
|
var el = e.target;
|
||||||
|
var row = el.parentNode.parentNode;
|
||||||
|
this.templates[row.dataset.id].name = el.value.trim() || 'untitled';
|
||||||
|
this.store();
|
||||||
|
el.style.backgroundColor = 'rgb(40, 95, 40)';
|
||||||
|
el.style.transitionDuration = '0s';
|
||||||
|
this.saveVisualCue = setTimeout(function () {
|
||||||
|
el.style.transitionDuration = '.7s';
|
||||||
|
el.style.backgroundColor = 'var(--comfy-input-bg)';
|
||||||
|
}, 15);
|
||||||
|
},
|
||||||
|
onkeypress: (e) => {
|
||||||
|
var el = e.target;
|
||||||
|
clearTimeout(this.saveVisualCue);
|
||||||
|
el.style.transitionDuration = '0s';
|
||||||
|
el.style.backgroundColor = 'var(--comfy-input-bg)';
|
||||||
|
},
|
||||||
|
$: (el) => (nameInput = el),
|
||||||
|
})
|
||||||
|
]
|
||||||
|
),
|
||||||
|
$el(
|
||||||
|
"div",
|
||||||
|
{},
|
||||||
|
[
|
||||||
|
$el("button", {
|
||||||
|
textContent: "Export",
|
||||||
|
style: {
|
||||||
|
fontSize: "12px",
|
||||||
|
fontWeight: "normal",
|
||||||
|
},
|
||||||
|
onclick: (e) => {
|
||||||
|
const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string
|
||||||
|
const blob = new Blob([json], {type: "application/json"});
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = $el("a", {
|
||||||
|
href: url,
|
||||||
|
download: (nameInput.value || t.name) + ".json",
|
||||||
|
style: {display: "none"},
|
||||||
|
parent: document.body,
|
||||||
|
});
|
||||||
|
a.click();
|
||||||
|
setTimeout(function () {
|
||||||
|
a.remove();
|
||||||
|
window.URL.revokeObjectURL(url);
|
||||||
|
}, 0);
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
$el("button", {
|
||||||
|
textContent: "Delete",
|
||||||
|
style: {
|
||||||
|
fontSize: "12px",
|
||||||
|
color: "red",
|
||||||
|
fontWeight: "normal",
|
||||||
|
},
|
||||||
|
onclick: (e) => {
|
||||||
|
const item = e.target.parentNode.parentNode;
|
||||||
|
item.parentNode.removeChild(item);
|
||||||
|
this.templates.splice(item.dataset.id*1, 1);
|
||||||
|
this.store();
|
||||||
|
// update the rows index, setTimeout ensures that the list is updated
|
||||||
|
var that = this;
|
||||||
|
setTimeout(function (){
|
||||||
|
that.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => {
|
||||||
|
el.dataset.id = i;
|
||||||
|
});
|
||||||
|
}, 0);
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
]
|
||||||
|
),
|
||||||
]
|
]
|
||||||
),
|
)
|
||||||
$el(
|
|
||||||
"div",
|
|
||||||
{},
|
|
||||||
[
|
|
||||||
$el("button", {
|
|
||||||
textContent: "Export",
|
|
||||||
style: {
|
|
||||||
fontSize: "12px",
|
|
||||||
fontWeight: "normal",
|
|
||||||
},
|
|
||||||
onclick: (e) => {
|
|
||||||
const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string
|
|
||||||
const blob = new Blob([json], {type: "application/json"});
|
|
||||||
const url = URL.createObjectURL(blob);
|
|
||||||
const a = $el("a", {
|
|
||||||
href: url,
|
|
||||||
download: (nameInput.value || t.name) + ".json",
|
|
||||||
style: {display: "none"},
|
|
||||||
parent: document.body,
|
|
||||||
});
|
|
||||||
a.click();
|
|
||||||
setTimeout(function () {
|
|
||||||
a.remove();
|
|
||||||
window.URL.revokeObjectURL(url);
|
|
||||||
}, 0);
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
$el("button", {
|
|
||||||
textContent: "Delete",
|
|
||||||
style: {
|
|
||||||
fontSize: "12px",
|
|
||||||
color: "red",
|
|
||||||
fontWeight: "normal",
|
|
||||||
},
|
|
||||||
onclick: (e) => {
|
|
||||||
nameInput.value = "";
|
|
||||||
e.target.parentElement.style.display = "none";
|
|
||||||
e.target.parentElement.previousElementSibling.style.display = "none";
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
];
|
];
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import { ComfyWidgets, getWidgetType } from "./widgets.js";
|
|||||||
import { ComfyUI, $el } from "./ui.js";
|
import { ComfyUI, $el } from "./ui.js";
|
||||||
import { api } from "./api.js";
|
import { api } from "./api.js";
|
||||||
import { defaultGraph } from "./defaultGraph.js";
|
import { defaultGraph } from "./defaultGraph.js";
|
||||||
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
|
import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
|
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
|
||||||
@ -1601,6 +1601,18 @@ export class ComfyApp {
|
|||||||
* @returns The workflow and node links
|
* @returns The workflow and node links
|
||||||
*/
|
*/
|
||||||
async graphToPrompt() {
|
async graphToPrompt() {
|
||||||
|
for (const outerNode of this.graph.computeExecutionOrder(false)) {
|
||||||
|
const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
|
||||||
|
for (const node of innerNodes) {
|
||||||
|
if (node.isVirtualNode) {
|
||||||
|
// Don't serialize frontend only nodes but let them make changes
|
||||||
|
if (node.applyToGraph) {
|
||||||
|
node.applyToGraph();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const workflow = this.graph.serialize();
|
const workflow = this.graph.serialize();
|
||||||
const output = {};
|
const output = {};
|
||||||
// Process nodes in order of execution
|
// Process nodes in order of execution
|
||||||
@ -1608,10 +1620,6 @@ export class ComfyApp {
|
|||||||
const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
|
const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
|
||||||
for (const node of innerNodes) {
|
for (const node of innerNodes) {
|
||||||
if (node.isVirtualNode) {
|
if (node.isVirtualNode) {
|
||||||
// Don't serialize frontend only nodes but let them make changes
|
|
||||||
if (node.applyToGraph) {
|
|
||||||
node.applyToGraph(workflow);
|
|
||||||
}
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1809,6 +1817,15 @@ export class ComfyApp {
|
|||||||
importA1111(this.graph, pngInfo.parameters);
|
importA1111(this.graph, pngInfo.parameters);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if (file.type === "image/webp") {
|
||||||
|
const pngInfo = await getWebpMetadata(file);
|
||||||
|
if (pngInfo) {
|
||||||
|
if (pngInfo.workflow) {
|
||||||
|
this.loadGraphData(JSON.parse(pngInfo.workflow));
|
||||||
|
} else if (pngInfo.Workflow) {
|
||||||
|
this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node.
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
|
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
|
||||||
const reader = new FileReader();
|
const reader = new FileReader();
|
||||||
reader.onload = async () => {
|
reader.onload = async () => {
|
||||||
|
|||||||
@ -47,6 +47,103 @@ export function getPngMetadata(file) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function parseExifData(exifData) {
|
||||||
|
// Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian)
|
||||||
|
const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949;
|
||||||
|
console.log(exifData);
|
||||||
|
|
||||||
|
// Function to read 16-bit and 32-bit integers from binary data
|
||||||
|
function readInt(offset, isLittleEndian, length) {
|
||||||
|
let arr = exifData.slice(offset, offset + length)
|
||||||
|
if (length === 2) {
|
||||||
|
return new DataView(arr.buffer, arr.byteOffset, arr.byteLength).getUint16(0, isLittleEndian);
|
||||||
|
} else if (length === 4) {
|
||||||
|
return new DataView(arr.buffer, arr.byteOffset, arr.byteLength).getUint32(0, isLittleEndian);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the offset to the first IFD (Image File Directory)
|
||||||
|
const ifdOffset = readInt(4, isLittleEndian, 4);
|
||||||
|
|
||||||
|
function parseIFD(offset) {
|
||||||
|
const numEntries = readInt(offset, isLittleEndian, 2);
|
||||||
|
const result = {};
|
||||||
|
|
||||||
|
for (let i = 0; i < numEntries; i++) {
|
||||||
|
const entryOffset = offset + 2 + i * 12;
|
||||||
|
const tag = readInt(entryOffset, isLittleEndian, 2);
|
||||||
|
const type = readInt(entryOffset + 2, isLittleEndian, 2);
|
||||||
|
const numValues = readInt(entryOffset + 4, isLittleEndian, 4);
|
||||||
|
const valueOffset = readInt(entryOffset + 8, isLittleEndian, 4);
|
||||||
|
|
||||||
|
// Read the value(s) based on the data type
|
||||||
|
let value;
|
||||||
|
if (type === 2) {
|
||||||
|
// ASCII string
|
||||||
|
value = String.fromCharCode(...exifData.slice(valueOffset, valueOffset + numValues - 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
result[tag] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the first IFD
|
||||||
|
const ifdData = parseIFD(ifdOffset);
|
||||||
|
return ifdData;
|
||||||
|
}
|
||||||
|
|
||||||
|
function splitValues(input) {
|
||||||
|
var output = {};
|
||||||
|
for (var key in input) {
|
||||||
|
var value = input[key];
|
||||||
|
var splitValues = value.split(':', 2);
|
||||||
|
output[splitValues[0]] = splitValues[1];
|
||||||
|
}
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getWebpMetadata(file) {
|
||||||
|
return new Promise((r) => {
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = (event) => {
|
||||||
|
const webp = new Uint8Array(event.target.result);
|
||||||
|
const dataView = new DataView(webp.buffer);
|
||||||
|
|
||||||
|
// Check that the WEBP signature is present
|
||||||
|
if (dataView.getUint32(0) !== 0x52494646 || dataView.getUint32(8) !== 0x57454250) {
|
||||||
|
console.error("Not a valid WEBP file");
|
||||||
|
r();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start searching for chunks after the WEBP signature
|
||||||
|
let offset = 12;
|
||||||
|
let txt_chunks = {};
|
||||||
|
// Loop through the chunks in the WEBP file
|
||||||
|
while (offset < webp.length) {
|
||||||
|
const chunk_length = dataView.getUint32(offset + 4, true);
|
||||||
|
const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4));
|
||||||
|
if (chunk_type === "EXIF") {
|
||||||
|
let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length));
|
||||||
|
for (var key in data) {
|
||||||
|
var value = data[key];
|
||||||
|
let index = value.indexOf(':');
|
||||||
|
txt_chunks[value.slice(0, index)] = value.slice(index + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
offset += 8 + chunk_length;
|
||||||
|
}
|
||||||
|
|
||||||
|
r(txt_chunks);
|
||||||
|
};
|
||||||
|
|
||||||
|
reader.readAsArrayBuffer(file);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
export function getLatentMetadata(file) {
|
export function getLatentMetadata(file) {
|
||||||
return new Promise((r) => {
|
return new Promise((r) => {
|
||||||
const reader = new FileReader();
|
const reader = new FileReader();
|
||||||
|
|||||||
@ -719,20 +719,22 @@ export class ComfyUI {
|
|||||||
filename += ".json";
|
filename += ".json";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string
|
app.graphToPrompt().then(p=>{
|
||||||
const blob = new Blob([json], {type: "application/json"});
|
const json = JSON.stringify(p.workflow, null, 2); // convert the data to a JSON string
|
||||||
const url = URL.createObjectURL(blob);
|
const blob = new Blob([json], {type: "application/json"});
|
||||||
const a = $el("a", {
|
const url = URL.createObjectURL(blob);
|
||||||
href: url,
|
const a = $el("a", {
|
||||||
download: filename,
|
href: url,
|
||||||
style: {display: "none"},
|
download: filename,
|
||||||
parent: document.body,
|
style: {display: "none"},
|
||||||
|
parent: document.body,
|
||||||
|
});
|
||||||
|
a.click();
|
||||||
|
setTimeout(function () {
|
||||||
|
a.remove();
|
||||||
|
window.URL.revokeObjectURL(url);
|
||||||
|
}, 0);
|
||||||
});
|
});
|
||||||
a.click();
|
|
||||||
setTimeout(function () {
|
|
||||||
a.remove();
|
|
||||||
window.URL.revokeObjectURL(url);
|
|
||||||
}, 0);
|
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
$el("button", {
|
$el("button", {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user