merging master into filesubflows and fixing merge conflicts

This commit is contained in:
Sammy Franklin 2023-11-29 12:53:02 -08:00
commit 32f7fd663a
67 changed files with 4359 additions and 3350 deletions

View File

@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x and SDXL
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
@ -30,6 +30,8 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.

View File

@ -27,7 +27,6 @@ class ControlNet(nn.Module):
model_channels,
hint_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
@ -52,8 +51,10 @@ class ControlNet(nn.Module):
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
device=None,
operations=comfy.ops,
**kwargs,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@ -79,10 +80,7 @@ class ControlNet(nn.Module):
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
@ -90,18 +88,16 @@ class ControlNet(nn.Module):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.attention_resolutions = attention_resolutions
transformer_depth = transformer_depth[:]
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
@ -180,11 +176,14 @@ class ControlNet(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
operations=operations
dtype=self.dtype,
device=device,
operations=operations,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
num_transformers = transformer_depth.pop(0)
if num_transformers > 0:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
@ -201,9 +200,9 @@ class ControlNet(nn.Module):
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, operations=operations
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@ -223,11 +222,13 @@ class ControlNet(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
dtype=self.dtype,
device=device,
operations=operations
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
)
)
)
@ -245,7 +246,7 @@ class ControlNet(nn.Module):
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
mid_block = [
ResBlock(
ch,
time_embed_dim,
@ -253,12 +254,15 @@ class ControlNet(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
),
SpatialTransformer( # always uses a self-attn
)]
if transformer_depth_middle >= 0:
mid_block += [SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, operations=operations
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
),
ResBlock(
ch,
@ -267,9 +271,11 @@ class ControlNet(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
),
)
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
self._feature_size += ch

View File

@ -36,6 +36,8 @@ parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
@ -60,6 +62,13 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")

View File

@ -1,5 +1,5 @@
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
from .utils import load_torch_file, transformers_convert
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
from .utils import load_torch_file, transformers_convert, common_upscale
import os
import torch
import contextlib
@ -7,6 +7,18 @@ import contextlib
import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.utils
def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
scale = (size / min(image.shape[1], image.shape[2]))
image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
class ClipVisionModel():
def __init__(self, json_config):
@ -23,25 +35,12 @@ class ClipVisionModel():
self.model.to(self.dtype)
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.processor = CLIPImageProcessor(crop_size=224,
do_center_crop=True,
do_convert_rgb=True,
do_normalize=True,
do_resize=True,
image_mean=[ 0.48145466,0.4578275,0.40821073],
image_std=[0.26862954,0.26130258,0.27577711],
resample=3, #bicubic
size=224)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
def encode_image(self, image):
img = torch.clip((255. * image), 0, 255).round().int()
img = list(map(lambda a: a, img))
inputs = self.processor(images=img, return_tensors="pt")
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = inputs['pixel_values'].to(self.load_device)
pixel_values = clip_preprocess(image.to(self.load_device))
if self.dtype != torch.float32:
precision_scope = torch.autocast

79
comfy/conds.py Normal file
View File

@ -0,0 +1,79 @@
import enum
import torch
import math
import comfy.utils
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
class CONDRegular:
def __init__(self, cond):
self.cond = cond
def _copy_with(self, cond):
return self.__class__(cond)
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
return True
def concat(self, others):
conds = [self.cond]
for x in others:
conds.append(x.cond)
return torch.cat(conds)
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
class CONDCrossAttn(CONDRegular):
def can_concat(self, other):
s1 = self.cond.shape
s2 = other.cond.shape
if s1 != s2:
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
return True
def concat(self, others):
conds = [self.cond]
crossattn_max_len = self.cond.shape[1]
for x in others:
c = x.cond
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
conds.append(c)
out = []
for c in conds:
if c.shape[1] < crossattn_max_len:
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
out.append(c)
return torch.cat(out)
class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
if self.cond != other.cond:
return False
return True
def concat(self, others):
return self.cond

View File

@ -33,7 +33,7 @@ class ControlBase:
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
self.timestep_percent_range = (1.0, 0.0)
self.timestep_percent_range = (0.0, 1.0)
self.timestep_range = None
if device is None:
@ -42,7 +42,7 @@ class ControlBase:
self.previous_controlnet = None
self.global_average_pooling = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
@ -132,6 +132,7 @@ class ControlNet(ControlBase):
self.control_model = control_model
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@ -156,10 +157,13 @@ class ControlNet(ControlBase):
context = cond['c_crossattn']
y = cond.get('c_adm', None)
y = cond.get('y', None)
if y is not None:
y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)
def copy(self):
@ -172,6 +176,14 @@ class ControlNet(ControlBase):
out.append(self.control_model_wrapped)
return out
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
self.model_sampling_current = model.model_sampling
def cleanup(self):
self.model_sampling_current = None
super().cleanup()
class ControlLoraOps:
class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True,

View File

@ -852,7 +852,13 @@ class SigmaConvert:
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
return log_mean_coeff - log_std
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
def predict_eps_sigma(model, input, sigma_in, **kwargs):
sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
input = input * ((sigma ** 2 + 1.0) ** 0.5)
return (input - model(input, sigma_in, **kwargs)) / sigma
def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
timesteps = sigmas.clone()
if sigmas[-1] == 0:
timesteps = sigmas[:]
@ -874,14 +880,14 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
model_type = "noise"
model_fn = model_wrapper(
model.predict_eps_sigma,
lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
ns,
model_type=model_type,
guidance_type="uncond",
model_kwargs=extra_args,
)
order = min(3, len(timesteps) - 1)
order = min(3, len(timesteps) - 2)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
x /= ns.marginal_alpha(timesteps[-1])

View File

@ -1,194 +0,0 @@
import math
import torch
from torch import nn
from . import sampling, utils
class VDenoiser(nn.Module):
"""A v-diffusion-pytorch model wrapper for k-diffusion."""
def __init__(self, inner_model):
super().__init__()
self.inner_model = inner_model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def sigma_to_t(self, sigma):
return sigma.atan() / math.pi * 2
def t_to_sigma(self, t):
return (t * math.pi / 2).tan()
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
class DiscreteSchedule(nn.Module):
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
levels."""
def __init__(self, sigmas, quantize):
super().__init__()
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
self.quantize = quantize
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def get_sigmas(self, n=None):
if n is None:
return sampling.append_zero(self.sigmas.flip(0))
t_max = len(self.sigmas) - 1
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
return sampling.append_zero(self.t_to_sigma(t))
def sigma_to_discrete_timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape)
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
if quantize:
return self.sigma_to_discrete_timestep(sigma)
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
t = (1 - w) * low_idx + w * high_idx
return t.view(sigma.shape)
def t_to_sigma(self, t):
t = t.float()
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t-low_idx if t.device.type == 'mps' else t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
def predict_eps_discrete_timestep(self, input, t, **kwargs):
if t.dtype != torch.int64 and t.dtype != torch.int32:
t = t.round()
sigma = self.t_to_sigma(t)
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
def predict_eps_sigma(self, input, sigma, **kwargs):
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
noise)."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_out = -sigma
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_out, c_in
def get_eps(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
return (eps - noise).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
return input + eps * c_out
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for OpenAI diffusion models."""
def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
super().__init__(model, alphas_cumprod, quantize=quantize)
self.has_learned_sigmas = has_learned_sigmas
def get_eps(self, *args, **kwargs):
model_output = self.inner_model(*args, **kwargs)
if self.has_learned_sigmas:
return model_output.chunk(2, dim=1)[0]
return model_output
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for CompVis diffusion models."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)
class DiscreteVDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output v."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def get_v(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
"""A wrapper for CompVis diffusion models that output v."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond)

View File

@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
return mu
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@ -737,3 +736,75 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
x = denoised
if sigmas[i + 1] > 0:
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
return x
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
s_end = sigmas[-1]
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == s_end:
# Euler method
x = x + d * dt
elif sigmas[i + 2] == s_end:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
w = 2 * sigmas[0]
w2 = sigmas[i+1]/w
w1 = 1 - w2
d_prime = d * w1 + d_2 * w2
x = x + d_prime * dt
else:
# Heun++
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
dt_2 = sigmas[i + 2] - sigmas[i + 1]
x_3 = x_2 + d_2 * dt_2
denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
w = 3 * sigmas[0]
w2 = sigmas[i + 1] / w
w3 = sigmas[i + 2] / w
w1 = 1 - w2 - w3
d_prime = w1 * d + w2 * d_2 + w3 * d_3
x = x + d_prime * dt
return x

View File

@ -1,418 +0,0 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
self.parameterization = kwargs.get("parameterization", "eps")
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.float().to(self.device)
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose)
def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True):
self.ddim_timesteps = torch.tensor(ddim_timesteps)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample_custom(self,
ddim_timesteps,
conditioning=None,
callback=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
denoise_function=None,
extra_args=None,
to_zero=True,
end_step=None,
disable_pbar=False,
**kwargs
):
self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
samples, intermediates = self.ddim_sampling(conditioning, x_T.shape,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule,
denoise_function=denoise_function,
extra_args=extra_args,
to_zero=to_zero,
end_step=end_step,
disable_pbar=disable_pbar
)
return samples, intermediates
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule,
denoise_function=None,
extra_args=None
)
return samples, intermediates
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
device = self.model.alphas_cumprod.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else timesteps.flip(0)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
# print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
if to_zero:
img = pred_x0
else:
if ddim_use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
img /= sqrt_alphas_cumprod[index - 1]
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None, denoise_function=None, extra_args=None):
b, *_, device = *x.shape, x.device
if denoise_function is not None:
model_output = denoise_function(x, t, **extra_args)
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [torch.cat([
unconditional_conditioning[k][i],
c[k][i]]) for i in range(len(c[k]))]
else:
c_in[k] = torch.cat([
unconditional_conditioning[k],
c[k]])
elif isinstance(c, list):
c_in = list()
assert isinstance(unconditional_conditioning, list)
for i in range(len(c)):
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
else:
c_in = torch.cat([unconditional_conditioning, c])
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.parameterization == "v":
e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
else:
e_t = model_output
if score_corrector is not None:
assert self.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
if max_denoise:
noise_multiplier = 1.0
else:
noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec

View File

@ -1 +0,0 @@
from .sampler import DPMSolverSampler

File diff suppressed because it is too large Load Diff

View File

@ -1,96 +0,0 @@
"""SAMPLING ONLY."""
import torch
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
MODEL_TYPES = {
"eps": "noise",
"v": "v"
}
class DPMSolverSampler(object):
def __init__(self, model, device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.device = device
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.to(self.device)
setattr(self, name, attr)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
if isinstance(ctmp, torch.Tensor):
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
else:
if isinstance(conditioning, torch.Tensor):
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
device = self.model.betas.device
if x_T is None:
img = torch.randn(size, device=device)
else:
img = x_T
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
ns,
model_type=MODEL_TYPES[self.model.parameterization],
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
lower_order_final=True)
return x.to(device), None

View File

@ -1,245 +0,0 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding
class PLMSSampler(object):
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.to(self.device)
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next,
dynamic_threshold=dynamic_threshold)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t

View File

@ -1,22 +0,0 @@
import torch
import numpy as np
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
return x[(...,) + (None,) * dims_to_append]
def norm_thresholding(x0, value):
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
return x0 * (value / s)
def spatial_norm_thresholding(x0, value):
# b c h w
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
return x0 * (value / s)

View File

@ -5,8 +5,10 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from functools import partial
from .diffusionmodules.util import checkpoint
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
@ -160,32 +162,19 @@ def attention_sub_quad(query, key, value, heads, mask=None):
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
kv_chunk_size_min = None
kv_chunk_size = None
query_chunk_size = None
#not sure at all about the math here
#TODO: tweak this
if mem_free_total > 8192 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 4
elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 2
else:
query_chunk_size_x = 1024
kv_chunk_size_min_x = None
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
if kv_chunk_size_x < 1024:
kv_chunk_size_x = None
for x in [4096, 2048, 1024, 512, 256]:
count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
if count >= k_tokens:
kv_chunk_size = k_tokens
query_chunk_size = x
break
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
query_chunk_size = q_tokens
kv_chunk_size = k_tokens
else:
query_chunk_size = query_chunk_size_x
kv_chunk_size = kv_chunk_size_x
kv_chunk_size_min = kv_chunk_size_min_x
if query_chunk_size is None:
query_chunk_size = 512
hidden_states = efficient_dot_product_attention(
query,
@ -222,9 +211,14 @@ def attention_split(q, k, v, heads, mask=None):
mem_free_total = model_management.get_free_memory(q.device)
if _ATTN_PRECISION =="fp32":
element_size = 4
else:
element_size = q.element_size()
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
modifier = 3
mem_required = tensor_size * modifier
steps = 1
@ -252,10 +246,10 @@ def attention_split(q, k, v, heads, mask=None):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
first_op_done = True
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
first_op_done = True
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
@ -284,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None):
)
return r1
BROKEN_XFORMERS = False
try:
x_vers = xformers.__version__
#I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
except:
pass
def attention_xformers(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
if BROKEN_XFORMERS:
if b * heads > 65535:
return attention_pytorch(q, k, v, heads, mask)
q, k, v = map(
lambda t: t.unsqueeze(3)
@ -378,53 +383,72 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
super().__init__()
self.ff_in = ff_in or inner_dim is not None
if inner_dim is None:
inner_dim = dim
self.is_res = inner_dim == dim
if self.ff_in:
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention:
if switch_temporal_ca_to_sa:
raise ValueError
else:
self.attn2 = None
else:
context_dim_attn2 = None
if not switch_temporal_ca_to_sa:
context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
def forward(self, x, context=None, transformer_options={}):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}):
extra_options = {}
block = None
block_index = 0
if "current_index" in transformer_options:
extra_options["transformer_index"] = transformer_options["current_index"]
if "block_index" in transformer_options:
block_index = transformer_options["block_index"]
extra_options["block_index"] = block_index
if "original_shape" in transformer_options:
extra_options["original_shape"] = transformer_options["original_shape"]
if "block" in transformer_options:
block = transformer_options["block"]
extra_options["block"] = block
if "cond_or_uncond" in transformer_options:
extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
transformer_patches = {}
block = transformer_options.get("block", None)
block_index = transformer_options.get("block_index", 0)
transformer_patches = {}
transformer_patches_replace = {}
for k in transformer_options:
if k == "patches":
transformer_patches = transformer_options[k]
elif k == "patches_replace":
transformer_patches_replace = transformer_options[k]
else:
extra_options[k] = transformer_options[k]
extra_options["n_heads"] = self.n_heads
extra_options["dim_head"] = self.d_head
if "patches_replace" in transformer_options:
transformer_patches_replace = transformer_options["patches_replace"]
else:
transformer_patches_replace = {}
if self.ff_in:
x_skip = x
x = self.ff_in(self.norm_in(x))
if self.is_res:
x += x_skip
n = self.norm1(x)
if self.disable_self_attn:
@ -473,31 +497,34 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
x = p(x, extra_options)
n = self.norm2(x)
context_attn2 = context
value_attn2 = None
if "attn2_patch" in transformer_patches:
patch = transformer_patches["attn2_patch"]
value_attn2 = context_attn2
for p in patch:
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
block_attn2 = transformer_block
if block_attn2 not in attn2_replace_patch:
block_attn2 = block
if block_attn2 in attn2_replace_patch:
if value_attn2 is None:
if self.attn2 is not None:
n = self.norm2(x)
if self.switch_temporal_ca_to_sa:
context_attn2 = n
else:
context_attn2 = context
value_attn2 = None
if "attn2_patch" in transformer_patches:
patch = transformer_patches["attn2_patch"]
value_attn2 = context_attn2
n = self.attn2.to_q(n)
context_attn2 = self.attn2.to_k(context_attn2)
value_attn2 = self.attn2.to_v(value_attn2)
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2)
for p in patch:
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
block_attn2 = transformer_block
if block_attn2 not in attn2_replace_patch:
block_attn2 = block
if block_attn2 in attn2_replace_patch:
if value_attn2 is None:
value_attn2 = context_attn2
n = self.attn2.to_q(n)
context_attn2 = self.attn2.to_k(context_attn2)
value_attn2 = self.attn2.to_v(value_attn2)
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2)
if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
@ -505,7 +532,12 @@ class BasicTransformerBlock(nn.Module):
n = p(n, extra_options)
x += n
x = self.ff(self.norm3(x)) + x
if self.is_res:
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
return x
@ -573,3 +605,164 @@ class SpatialTransformer(nn.Module):
x = self.proj_out(x)
return x + x_in
class SpatialVideoTransformer(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
dtype=None, device=None, operations=comfy.ops
):
super().__init__(
in_channels,
n_heads,
d_head,
depth=depth,
dropout=dropout,
use_checkpoint=checkpoint,
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
dtype=dtype, device=device, operations=operations
)
self.time_depth = time_depth
self.depth = depth
self.max_time_embed_period = max_time_embed_period
time_mix_d_head = d_head
n_time_mix_heads = n_heads
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
inner_dim = n_heads * d_head
if use_spatial_context:
time_context_dim = context_dim
self.time_stack = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=time_context_dim,
# timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
inner_dim=time_mix_inner_dim,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
dtype=dtype, device=device, operations=operations
)
for _ in range(self.depth)
]
)
assert len(self.time_stack) == len(self.transformer_blocks)
self.use_spatial_context = use_spatial_context
self.in_channels = in_channels
time_embed_dim = self.in_channels * 4
self.time_pos_embed = nn.Sequential(
operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
)
self.time_mixer = AlphaBlender(
alpha=merge_factor, merge_strategy=merge_strategy
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
spatial_context = context
if self.use_spatial_context:
assert (
context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}"
if time_context is None:
time_context = context
time_context_first_timestep = time_context[::timesteps]
time_context = repeat(
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
)
elif time_context is not None and not self.use_spatial_context:
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
if time_context.ndim == 2:
time_context = rearrange(time_context, "b c -> b 1 c")
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
if self.use_linear:
x = self.proj_in(x)
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)
):
transformer_options["block_index"] = it_
x = block(
x,
context=spatial_context,
transformer_options=transformer_options,
)
x_mix = x
x_mix = x_mix + emb
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out

View File

@ -5,6 +5,8 @@ import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from functools import partial
from .util import (
checkpoint,
@ -12,8 +14,9 @@ from .util import (
zero_module,
normalization,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.ops
@ -28,6 +31,26 @@ class TimestepBlock(nn.Module):
Apply the module to `x` given `emb` timestep embeddings.
"""
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
for layer in ts:
if isinstance(layer, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialVideoTransformer):
x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
if "transformer_index" in transformer_options:
transformer_options["transformer_index"] += 1
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
if "transformer_index" in transformer_options:
transformer_options["transformer_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
@ -35,31 +58,8 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
support it as an extra input.
"""
def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in ts:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
transformer_options["current_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
def forward(self, *args, **kwargs):
return forward_timestep_embed(self, *args, **kwargs)
class Upsample(nn.Module):
"""
@ -154,6 +154,9 @@ class ResBlock(TimestepBlock):
use_checkpoint=False,
up=False,
down=False,
kernel_size=3,
exchange_temb_dims=False,
skip_t_emb=False,
dtype=None,
device=None,
operations=comfy.ops
@ -166,11 +169,17 @@ class ResBlock(TimestepBlock):
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.exchange_temb_dims = exchange_temb_dims
if isinstance(kernel_size, list):
padding = [k // 2 for k in kernel_size]
else:
padding = kernel_size // 2
self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
)
self.updown = up or down
@ -184,19 +193,24 @@ class ResBlock(TimestepBlock):
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
operations.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
),
)
self.skip_t_emb = skip_t_emb
if self.skip_t_emb:
self.emb_layers = None
self.exchange_temb_dims = False
else:
self.emb_layers = nn.Sequential(
nn.SiLU(),
operations.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
),
)
self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
),
)
@ -204,7 +218,7 @@ class ResBlock(TimestepBlock):
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = operations.conv_nd(
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
)
else:
self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
@ -230,19 +244,110 @@ class ResBlock(TimestepBlock):
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
emb_out = None
if not self.skip_t_emb:
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_norm(h)
if emb_out is not None:
scale, shift = th.chunk(emb_out, 2, dim=1)
h *= (1 + scale)
h += shift
h = out_rest(h)
else:
h = h + emb_out
if emb_out is not None:
if self.exchange_temb_dims:
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class VideoResBlock(ResBlock):
def __init__(
self,
channels: int,
emb_channels: int,
dropout: float,
video_kernel_size=3,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
out_channels=None,
use_conv: bool = False,
use_scale_shift_norm: bool = False,
dims: int = 2,
use_checkpoint: bool = False,
up: bool = False,
down: bool = False,
dtype=None,
device=None,
operations=comfy.ops
):
super().__init__(
channels,
emb_channels,
dropout,
out_channels=out_channels,
use_conv=use_conv,
use_scale_shift_norm=use_scale_shift_norm,
dims=dims,
use_checkpoint=use_checkpoint,
up=up,
down=down,
dtype=dtype,
device=device,
operations=operations
)
self.time_stack = ResBlock(
default(out_channels, channels),
emb_channels,
dropout=dropout,
dims=3,
out_channels=default(out_channels, channels),
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=use_checkpoint,
exchange_temb_dims=True,
dtype=dtype,
device=device,
operations=operations
)
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
rearrange_pattern="b t -> b 1 t 1 1",
)
def forward(
self,
x: th.Tensor,
emb: th.Tensor,
num_video_frames: int,
image_only_indicator = None,
) -> th.Tensor:
x = super().forward(x, emb)
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = self.time_stack(
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
)
x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
)
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class Timestep(nn.Module):
def __init__(self, dim):
super().__init__()
@ -251,6 +356,15 @@ class Timestep(nn.Module):
def forward(self, t):
return timestep_embedding(t, self.dim)
def apply_control(h, control, name):
if control is not None and name in control and len(control[name]) > 0:
ctrl = control[name].pop()
if ctrl is not None:
try:
h += ctrl
except:
print("warning control could not be applied", h.shape, ctrl.shape)
return h
class UNetModel(nn.Module):
"""
@ -259,10 +373,6 @@ class UNetModel(nn.Module):
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
@ -289,7 +399,6 @@ class UNetModel(nn.Module):
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
@ -314,6 +423,17 @@ class UNetModel(nn.Module):
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
use_temporal_resblock=False,
use_temporal_attention=False,
time_context_dim=None,
extra_ff_mix_layer=False,
use_spatial_context=False,
merge_strategy=None,
merge_factor=0.0,
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
device=None,
operations=comfy.ops,
):
@ -341,10 +461,7 @@ class UNetModel(nn.Module):
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
@ -352,18 +469,16 @@ class UNetModel(nn.Module):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.attention_resolutions = attention_resolutions
transformer_depth = transformer_depth[:]
transformer_depth_output = transformer_depth_output[:]
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
@ -373,8 +488,12 @@ class UNetModel(nn.Module):
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.use_temporal_resblocks = use_temporal_resblock
self.predict_codebook_ids = n_embed is not None
self.default_num_video_frames = None
self.default_image_only_indicator = None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
@ -411,13 +530,104 @@ class UNetModel(nn.Module):
input_block_chans = [model_channels]
ch = model_channels
ds = 1
def get_attention_layer(
ch,
num_heads,
dim_head,
depth=1,
context_dim=None,
use_checkpoint=False,
disable_self_attn=False,
):
if use_temporal_attention:
return SpatialVideoTransformer(
ch,
num_heads,
dim_head,
depth=depth,
context_dim=context_dim,
time_context_dim=time_context_dim,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
merge_strategy=merge_strategy,
merge_factor=merge_factor,
checkpoint=use_checkpoint,
use_linear=use_linear_in_transformer,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
dtype=self.dtype, device=device, operations=operations
)
else:
return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
def get_resblock(
merge_factor,
merge_strategy,
video_kernel_size,
ch,
time_embed_dim,
dropout,
out_channels,
dims,
use_checkpoint,
use_scale_shift_norm,
down=False,
up=False,
dtype=None,
device=None,
operations=comfy.ops
):
if self.use_temporal_resblocks:
return VideoResBlock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
dtype=dtype,
device=device,
operations=operations
)
else:
return ResBlock(
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_channels,
use_checkpoint=use_checkpoint,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
dtype=dtype,
device=device,
operations=operations
)
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
@ -428,7 +638,8 @@ class UNetModel(nn.Module):
)
]
ch = mult * model_channels
if ds in attention_resolutions:
num_transformers = transformer_depth.pop(0)
if num_transformers > 0:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
@ -443,11 +654,9 @@ class UNetModel(nn.Module):
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
layers.append(get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
@ -456,10 +665,13 @@ class UNetModel(nn.Module):
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
@ -488,35 +700,43 @@ class UNetModel(nn.Module):
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
mid_block = [
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
),
SpatialTransformer( # always uses a self-attn
)]
if transformer_depth_middle >= 0:
mid_block += [get_attention_layer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
),
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
),
)
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
@ -524,10 +744,13 @@ class UNetModel(nn.Module):
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch + ich,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
@ -538,7 +761,8 @@ class UNetModel(nn.Module):
)
]
ch = model_channels * mult
if ds in attention_resolutions:
num_transformers = transformer_depth_output.pop()
if num_transformers > 0:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
@ -554,19 +778,21 @@ class UNetModel(nn.Module):
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append(
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
)
)
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
@ -605,9 +831,13 @@ class UNetModel(nn.Module):
:return: an [N x C x ...] Tensor of outputs.
"""
transformer_options["original_shape"] = list(x.shape)
transformer_options["current_index"] = 0
transformer_options["transformer_index"] = 0
transformer_patches = transformer_options.get("patches", {})
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
time_context = kwargs.get("time_context", None)
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
@ -622,26 +852,28 @@ class UNetModel(nn.Module):
h = x.type(self.dtype)
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options)
if control is not None and 'input' in control and len(control['input']) > 0:
ctrl = control['input'].pop()
if ctrl is not None:
h += ctrl
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'input')
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
for p in patch:
h = p(h, transformer_options)
hs.append(h)
if "input_block_patch_after_skip" in transformer_patches:
patch = transformer_patches["input_block_patch_after_skip"]
for p in patch:
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
if control is not None and 'middle' in control and len(control['middle']) > 0:
ctrl = control['middle'].pop()
if ctrl is not None:
h += ctrl
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')
for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
hsp = hs.pop()
if control is not None and 'output' in control and len(control['output']) > 0:
ctrl = control['output'].pop()
if ctrl is not None:
hsp += ctrl
hsp = apply_control(hsp, control, 'output')
if "output_block_patch" in transformer_patches:
patch = transformer_patches["output_block_patch"]
@ -654,7 +886,7 @@ class UNetModel(nn.Module):
output_shape = hs[-1].shape
else:
output_shape = None
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)

View File

@ -13,11 +13,78 @@ import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from einops import repeat, rearrange
from comfy.ldm.util import instantiate_from_config
import comfy.ops
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
rearrange_pattern: str = "b t -> (b t) 1 1",
):
super().__init__()
self.merge_strategy = merge_strategy
self.rearrange_pattern = rearrange_pattern
assert (
merge_strategy in self.strategies
), f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif (
self.merge_strategy == "learned"
or self.merge_strategy == "learned_with_images"
):
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
if self.merge_strategy == "fixed":
# make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
else:
raise NotImplementedError()
return alpha
def forward(
self,
x_spatial,
x_temporal,
image_only_indicator=None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
x = (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
)
return x
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (
@ -170,8 +237,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:

View File

@ -83,7 +83,8 @@ def _summarize_chunk(
)
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
torch.exp(attn_weights - max_score, out=attn_weights)
attn_weights -= max_score
torch.exp(attn_weights, out=attn_weights)
exp_weights = attn_weights.to(value.dtype)
exp_values = torch.bmm(exp_weights, value)
max_score = max_score.squeeze(-1)

View File

@ -0,0 +1,244 @@
import functools
from typing import Callable, Iterable, Union
import torch
from einops import rearrange, repeat
import comfy.ops
from .diffusionmodules.model import (
AttnBlock,
Decoder,
ResnetBlock,
)
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
from .attention import BasicTransformerBlock
def partialclass(cls, *args, **kwargs):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
return NewCls
class VideoResBlock(ResnetBlock):
def __init__(
self,
out_channels,
*args,
dropout=0.0,
video_kernel_size=3,
alpha=0.0,
merge_strategy="learned",
**kwargs,
):
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
if video_kernel_size is None:
video_kernel_size = [3, 1, 1]
self.time_stack = ResBlock(
channels=out_channels,
emb_channels=0,
dropout=dropout,
dims=3,
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=False,
skip_t_emb=True,
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, bs):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError()
def forward(self, x, temb, skip_video=False, timesteps=None):
b, c, h, w = x.shape
if timesteps is None:
timesteps = b
x = super().forward(x, temb)
if not skip_video:
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class AE3DConv(torch.nn.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
padding = [int(k // 2) for k in video_kernel_size]
else:
padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
padding=padding,
)
def forward(self, input, timesteps=None, skip_video=False):
if timesteps is None:
timesteps = input.shape[0]
x = super().forward(input)
if skip_video:
return x
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_mix_conv(x)
return rearrange(x, "b c t h w -> (b t) c h w")
class AttnVideoBlock(AttnBlock):
def __init__(
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = BasicTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
comfy.ops.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
comfy.ops.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def forward(self, x, timesteps=None, skip_time_block=False):
if skip_time_block:
return super().forward(x)
if timesteps is None:
timesteps = x.shape[0]
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(
self,
):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
def make_time_attn(
in_channels,
attn_type="vanilla",
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
):
return partialclass(
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
)
class Conv2DWrapper(torch.nn.Conv2d):
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
return super().forward(input)
class VideoDecoder(Decoder):
available_time_modes = ["all", "conv-only", "attn-only"]
def __init__(
self,
*args,
video_kernel_size: Union[int, list] = 3,
alpha: float = 0.0,
merge_strategy: str = "learned",
time_mode: str = "conv-only",
**kwargs,
):
self.video_kernel_size = video_kernel_size
self.alpha = alpha
self.merge_strategy = merge_strategy
self.time_mode = time_mode
assert (
self.time_mode in self.available_time_modes
), f"time_mode parameter has to be in {self.available_time_modes}"
if self.time_mode != "attn-only":
kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
if self.time_mode not in ["conv-only", "only-last-conv"]:
kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
if self.time_mode not in ["attn-only", "only-last-conv"]:
kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)
super().__init__(*args, **kwargs)
def get_last_layer(self, skip_time_mix=False, **kwargs):
if self.time_mode == "attn-only":
raise NotImplementedError("TODO")
else:
return (
self.conv_out.time_mix_conv.weight
if not skip_time_mix
else self.conv_out.weight
)

View File

@ -131,6 +131,18 @@ def load_lora(lora, to_load):
loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
patch_dict[to_load[x]] = (diff_weight,)
loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
loaded_keys.add(diff_bias_name)
for x in lora.keys():
if x not in loaded_keys:
print("lora key not loaded", x)
@ -141,9 +153,9 @@ def model_lora_keys_clip(model, key_map={}):
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
for b in range(32):
for b in range(32): #TODO: clean up
for c in LORA_CLIP_MAP:
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
@ -154,6 +166,8 @@ def model_lora_keys_clip(model, key_map={}):
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
clip_l_present = True

View File

@ -1,16 +1,37 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import numpy as np
import comfy.conds
from enum import Enum
from . import utils
class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
V_PREDICTION_EDM = 3
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
def model_sampling(model_config, model_type):
s = ModelSamplingDiscrete
if model_type == ModelType.EPS:
c = EPS
elif model_type == ModelType.V_PREDICTION:
c = V_PREDICTION
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM
class ModelSampling(s, c):
pass
return ModelSampling(model_config)
class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
@ -19,10 +40,12 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
@ -30,38 +53,25 @@ class BaseModel(torch.nn.Module):
print("model_type", model_type.name)
print("adm", self.adm_channels)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if given_betas is not None:
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
xc = torch.cat([x] + [c_concat], dim=1)
else:
xc = x
xc = torch.cat([xc] + [c_concat], dim=1)
context = c_crossattn
dtype = self.get_dtype()
xc = xc.to(dtype)
t = t.to(dtype)
t = self.model_sampling.timestep(t).float()
context = context.to(dtype)
if c_adm is not None:
c_adm = c_adm.to(dtype)
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "to"):
extra = extra.to(dtype)
extra_conds[o] = extra
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def get_dtype(self):
return self.diffusion_model.dtype
@ -72,7 +82,8 @@ class BaseModel(torch.nn.Module):
def encode_adm(self, **kwargs):
return None
def cond_concat(self, **kwargs):
def extra_conds(self, **kwargs):
out = {}
if self.inpaint_model:
concat_keys = ("mask", "masked_image")
cond_concat = []
@ -101,8 +112,12 @@ class BaseModel(torch.nn.Module):
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
return cond_concat
return None
data = torch.cat(cond_concat, dim=1)
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
return out
def load_model_weights(self, sd, unet_prefix=""):
to_load = {}
@ -111,6 +126,7 @@ class BaseModel(torch.nn.Module):
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
print("unet missing:", m)
@ -147,6 +163,17 @@ class BaseModel(torch.nn.Module):
def set_inpaint(self):
self.inpaint_model = True
def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
#TODO: this needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * input_shape[2] * input_shape[3]
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
adm_inputs = []
weights = []
@ -241,3 +268,48 @@ class SDXL(BaseModel):
out.append(self.embedder(torch.Tensor([target_width])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
fps_id = kwargs.get("fps", 6) - 1
motion_bucket_id = kwargs.get("motion_bucket_id", 127)
augmentation = kwargs.get("augmentation_level", 0)
out = []
out.append(self.embedder(torch.Tensor([fps_id])))
out.append(self.embedder(torch.Tensor([motion_bucket_id])))
out.append(self.embedder(torch.Tensor([augmentation])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
def extra_conds(self, **kwargs):
out = {}
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if latent_image is None:
latent_image = torch.zeros_like(noise)
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
if "time_conditioning" in kwargs:
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
return out

View File

@ -14,6 +14,20 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
transformer_prefix = prefix + "1.transformer_blocks."
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
if len(transformer_keys) > 0:
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
return None
def detect_unet_config(state_dict, key_prefix, dtype):
state_dict_keys = list(state_dict.keys())
@ -40,72 +54,95 @@ def detect_unet_config(state_dict, key_prefix, dtype):
channel_mult = []
attention_resolutions = []
transformer_depth = []
transformer_depth_output = []
context_dim = None
use_linear_in_transformer = False
video_model = False
current_res = 1
count = 0
last_res_blocks = 0
last_transformer_depth = 0
last_channel_mult = 0
while True:
input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.')
for count in range(input_block_count):
prefix = '{}input_blocks.{}.'.format(key_prefix, count)
prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1)
block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
if len(block_keys) == 0:
break
block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys)))
if "{}0.op.weight".format(prefix) in block_keys: #new layer
if last_transformer_depth > 0:
attention_resolutions.append(current_res)
transformer_depth.append(last_transformer_depth)
num_res_blocks.append(last_res_blocks)
channel_mult.append(last_channel_mult)
current_res *= 2
last_res_blocks = 0
last_transformer_depth = 0
last_channel_mult = 0
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
if out is not None:
transformer_depth_output.append(out[0])
else:
transformer_depth_output.append(0)
else:
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
if res_block_prefix in block_keys:
last_res_blocks += 1
last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
transformer_prefix = prefix + "1.transformer_blocks."
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
if len(transformer_keys) > 0:
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
if context_dim is None:
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
if out is not None:
transformer_depth.append(out[0])
if context_dim is None:
context_dim = out[1]
use_linear_in_transformer = out[2]
video_model = out[3]
else:
transformer_depth.append(0)
res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
if res_block_prefix in block_keys_output:
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
if out is not None:
transformer_depth_output.append(out[0])
else:
transformer_depth_output.append(0)
count += 1
if last_transformer_depth > 0:
attention_resolutions.append(current_res)
transformer_depth.append(last_transformer_depth)
num_res_blocks.append(last_res_blocks)
channel_mult.append(last_channel_mult)
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
if len(set(num_res_blocks)) == 1:
num_res_blocks = num_res_blocks[0]
if len(set(transformer_depth)) == 1:
transformer_depth = transformer_depth[0]
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
else:
transformer_depth_middle = -1
unet_config["in_channels"] = in_channels
unet_config["model_channels"] = model_channels
unet_config["num_res_blocks"] = num_res_blocks
unet_config["attention_resolutions"] = attention_resolutions
unet_config["transformer_depth"] = transformer_depth
unet_config["transformer_depth_output"] = transformer_depth_output
unet_config["channel_mult"] = channel_mult
unet_config["transformer_depth_middle"] = transformer_depth_middle
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
unet_config["context_dim"] = context_dim
if video_model:
unet_config["extra_ff_mix_layer"] = True
unet_config["use_spatial_context"] = True
unet_config["merge_strategy"] = "learned_with_images"
unet_config["merge_factor"] = 0.0
unet_config["video_kernel_size"] = [3, 1, 1]
unet_config["use_temporal_resblock"] = True
unet_config["use_temporal_attention"] = True
else:
unet_config["use_temporal_resblock"] = False
unet_config["use_temporal_attention"] = False
return unet_config
def model_config_from_unet_config(unet_config):
@ -124,19 +161,65 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma
else:
return model_config
def convert_config(unet_config):
new_config = unet_config.copy()
num_res_blocks = new_config.get("num_res_blocks", None)
channel_mult = new_config.get("channel_mult", None)
if isinstance(num_res_blocks, int):
num_res_blocks = len(channel_mult) * [num_res_blocks]
if "attention_resolutions" in new_config:
attention_resolutions = new_config.pop("attention_resolutions")
transformer_depth = new_config.get("transformer_depth", None)
transformer_depth_middle = new_config.get("transformer_depth_middle", None)
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
t_in = []
t_out = []
s = 1
for i in range(len(num_res_blocks)):
res = num_res_blocks[i]
d = 0
if s in attention_resolutions:
d = transformer_depth[i]
t_in += [d] * res
t_out += [d] * (res + 1)
s *= 2
transformer_depth = t_in
transformer_depth_output = t_out
new_config["transformer_depth"] = t_in
new_config["transformer_depth_output"] = t_out
new_config["transformer_depth_middle"] = transformer_depth_middle
new_config["num_res_blocks"] = num_res_blocks
return new_config
def unet_config_from_diffusers_unet(state_dict, dtype):
match = {}
attention_resolutions = []
transformer_depth = []
attn_res = 1
for i in range(5):
k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i)
if k in state_dict:
match["context_dim"] = state_dict[k].shape[1]
attention_resolutions.append(attn_res)
attn_res *= 2
down_blocks = count_blocks(state_dict, "down_blocks.{}")
for i in range(down_blocks):
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
for ab in range(attn_blocks):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count)
if transformer_count > 0:
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
match["attention_resolutions"] = attention_resolutions
attn_res *= 2
if attn_blocks == 0:
transformer_depth.append(0)
transformer_depth.append(0)
match["transformer_depth"] = transformer_depth
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
@ -148,50 +231,65 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4,
'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2],
'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True,
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint]
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
for unet_config in supported_models:
matches = True
@ -200,7 +298,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
matches = False
break
if matches:
return unet_config
return convert_config(unet_config)
return None
def model_config_from_diffusers_unet(state_dict, dtype):

View File

@ -133,6 +133,10 @@ else:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
try:
XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
except:
pass
try:
XFORMERS_VERSION = xformers.version.__version__
print("xformers version:", XFORMERS_VERSION)
@ -478,6 +482,21 @@ def text_encoder_device():
else:
return torch.device("cpu")
def text_encoder_dtype(device=None):
if args.fp8_e4m3fn_text_enc:
return torch.float8_e4m3fn
elif args.fp8_e5m2_text_enc:
return torch.float8_e5m2
elif args.fp16_text_enc:
return torch.float16
elif args.fp32_text_enc:
return torch.float32
if should_use_fp16(device, prioritize_performance=False):
return torch.float16
else:
return torch.float32
def vae_device():
return get_torch_device()
@ -579,27 +598,6 @@ def get_free_memory(dev=None, torch_free_too=False):
else:
return mem_free_total
def batch_area_memory(area):
if xformers_enabled() or pytorch_attention_flash_attention():
#TODO: these formulas are copied from maximum_batch_area below
return (area / 20) * (1024 * 1024)
else:
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def maximum_batch_area():
global vram_state
if vram_state == VRAMState.NO_VRAM:
return 0
memory_free = get_free_memory() / (1024 * 1024)
if xformers_enabled() or pytorch_attention_flash_attention():
#TODO: this needs to be tweaked
area = 20 * memory_free
else:
#TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
area = ((memory_free - 1024) * 0.9) / (0.6)
return int(max(area, 0))
def cpu_mode():
global cpu_state
return cpu_state == CPUState.CPU

View File

@ -6,11 +6,13 @@ import comfy.utils
import comfy.model_management
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size
self.model = model
self.patches = {}
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
@ -20,6 +22,8 @@ class ModelPatcher:
else:
self.current_device = current_device
self.weight_inplace_update = weight_inplace_update
def model_size(self):
if self.size > 0:
return self.size
@ -33,11 +37,12 @@ class ModelPatcher:
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n
@ -47,6 +52,9 @@ class ModelPatcher:
return True
return False
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
@ -88,9 +96,18 @@ class ModelPatcher:
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def set_model_input_block_patch(self, patch):
self.set_model_patch(patch, "input_block_patch")
def set_model_input_block_patch_after_skip(self, patch):
self.set_model_patch(patch, "input_block_patch_after_skip")
def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch")
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
@ -107,10 +124,10 @@ class ModelPatcher:
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
if "unet_wrapper_function" in self.model_options:
wrap_func = self.model_options["unet_wrapper_function"]
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, "to"):
self.model_options["unet_wrapper_function"] = wrap_func.to(device)
self.model_options["model_function_wrapper"] = wrap_func.to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
@ -128,6 +145,7 @@ class ModelPatcher:
return list(p)
def get_key_patches(self, filter_prefix=None):
comfy.model_management.unload_model_clones(self)
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
@ -150,6 +168,12 @@ class ModelPatcher:
return sd
def patch_model(self, device_to=None):
for k in self.object_patches:
old = getattr(self.model, k)
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k])
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
@ -158,15 +182,20 @@ class ModelPatcher:
weight = model_sd[key]
inplace_update = self.weight_inplace_update
if key not in self.backup:
self.backup[key] = weight.to(self.offload_device)
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
comfy.utils.set_attr(self.model, key, out_weight)
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
@ -282,11 +311,21 @@ class ModelPatcher:
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
if self.weight_inplace_update:
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
keys = list(self.object_patches_backup.keys())
for k in keys:
setattr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup = {}

129
comfy/model_sampling.py Normal file
View File

@ -0,0 +1,129 @@
import torch
import numpy as np
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math
class EPS:
def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
beta_schedule = "linear"
if model_config is not None:
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.sigma_data = 1.0
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if given_betas is not None:
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
def sigma(self, timestep):
t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
self.sigma_data = 1.0
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max)
def set_sigma_range(self, sigma_min, sigma_max):
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return 0.25 * sigma.log()
def sigma(self, timestep):
return (timestep / 0.25).exp()
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)

View File

@ -1,29 +1,23 @@
import torch
from contextlib import contextmanager
class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
else:
self.register_parameter('bias', None)
def forward(self, input):
return torch.nn.functional.linear(input, self.weight, self.bias)
class Linear(torch.nn.Linear):
def reset_parameters(self):
return None
class Conv2d(torch.nn.Conv2d):
def reset_parameters(self):
return None
class Conv3d(torch.nn.Conv3d):
def reset_parameters(self):
return None
def conv_nd(dims, *args, **kwargs):
if dims == 2:
return Conv2d(*args, **kwargs)
elif dims == 3:
return Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")

View File

@ -1,6 +1,7 @@
import torch
import comfy.model_management
import comfy.samplers
import comfy.conds
import comfy.utils
import math
import numpy as np
@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device):
noise_mask = noise_mask.to(device)
return noise_mask
def broadcast_cond(cond, batch, device):
"""broadcasts conditioning to the batch size"""
copy = []
for p in cond:
t = comfy.utils.repeat_to_batch_size(p[0], batch)
t = t.to(device)
copy += [[t] + p[1:]]
return copy
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c[1]:
models += [c[1][model_type]]
if model_type in c:
models += [c[model_type]]
return models
def convert_cond(cond):
out = []
for c in cond:
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0])
temp["model_conds"] = model_conds
out.append(temp)
return out
def get_additional_models(positive, negative, dtype):
"""loads additional models in positive and negative conditioning"""
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
@ -72,18 +75,18 @@ def cleanup_additional_models(models):
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
device = model.load_device
positive = convert_cond(positive)
negative = convert_cond(negative)
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise_shape, device)
real_model = None
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model
positive_copy = broadcast_cond(positive, noise_shape[0], device)
negative_copy = broadcast_cond(negative, noise_shape[0], device)
return real_model, positive_copy, negative_copy, noise_mask, models
return real_model, positive, negative, noise_mask, models
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):

View File

@ -1,48 +1,42 @@
from .k_diffusion import sampling as k_diffusion_sampling
from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc
import torch
import enum
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math
from comfy import model_base
import comfy.utils
import comfy.conds
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
#The main sampling function shared by all the samplers
#Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(cond, x_in, timestep_in):
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in cond[1]:
timestep_start = cond[1]['timestep_start']
if 'timestep_start' in conds:
timestep_start = conds['timestep_start']
if timestep_in[0] > timestep_start:
return None
if 'timestep_end' in cond[1]:
timestep_end = cond[1]['timestep_end']
if 'timestep_end' in conds:
timestep_end = conds['timestep_end']
if timestep_in[0] < timestep_end:
return None
if 'area' in cond[1]:
area = cond[1]['area']
if 'strength' in cond[1]:
strength = cond[1]['strength']
adm_cond = None
if 'adm_encoded' in cond[1]:
adm_cond = cond[1]['adm_encoded']
if 'area' in conds:
area = conds['area']
if 'strength' in conds:
strength = conds['strength']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
if 'mask' in cond[1]:
if 'mask' in conds:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in cond[1]:
mask_strength = cond[1]["mask_strength"]
mask = cond[1]['mask']
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
@ -51,7 +45,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in cond[1]:
if 'mask' not in conds:
rr = 8
if area[2] != 0:
for t in range(rr):
@ -67,27 +61,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {}
conditionning['c_crossattn'] = cond[0]
if 'concat' in cond[1]:
cond_concat_in = cond[1]['concat']
if cond_concat_in is not None and len(cond_concat_in) > 0:
cropped = []
for x in cond_concat_in:
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1)
if adm_cond is not None:
conditionning['c_adm'] = adm_cond
model_conds = conds["model_conds"]
for c in model_conds:
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = None
if 'control' in cond[1]:
control = cond[1]['control']
if 'control' in conds:
control = conds['control']
patches = None
if 'gligen' in cond[1]:
gligen = cond[1]['gligen']
if 'gligen' in conds:
gligen = conds['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
@ -105,22 +89,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
return True
if c1.keys() != c2.keys():
return False
if 'c_crossattn' in c1:
s1 = c1['c_crossattn'].shape
s2 = c2['c_crossattn'].shape
if s1 != s2:
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if 'c_concat' in c1:
if c1['c_concat'].shape != c2['c_concat'].shape:
return False
if 'c_adm' in c1:
if c1['c_adm'].shape != c2['c_adm'].shape:
for k in c1:
if not c1[k].can_concat(c2[k]):
return False
return True
@ -149,39 +119,27 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
c_concat = []
c_adm = []
crossattn_max_len = 0
for x in c_list:
if 'c_crossattn' in x:
c = x['c_crossattn']
if crossattn_max_len == 0:
crossattn_max_len = c.shape[1]
else:
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
c_crossattn.append(c)
if 'c_concat' in x:
c_concat.append(x['c_concat'])
if 'c_adm' in x:
c_adm.append(x['c_adm'])
out = {}
c_crossattn_out = []
for c in c_crossattn:
if c.shape[1] < crossattn_max_len:
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
c_crossattn_out.append(c)
if len(c_crossattn_out) > 0:
out['c_crossattn'] = torch.cat(c_crossattn_out)
if len(c_concat) > 0:
out['c_concat'] = torch.cat(c_concat)
if len(c_adm) > 0:
out['c_adm'] = torch.cat(c_adm)
temp = {}
for x in c_list:
for k in x:
cur = temp.get(k, [])
cur.append(x[k])
temp[k] = cur
out = {}
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0
out_count = torch.ones_like(x_in) * 1e-37
out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in)/100000.0
out_uncond_count = torch.ones_like(x_in) * 1e-37
COND = 0
UNCOND = 1
@ -212,9 +170,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
if (len(batch_amount) * first_shape[0] * first_shape[2] * first_shape[3] < max_total_area):
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
@ -260,12 +220,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
@ -281,39 +243,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
del out_count
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
max_total_area = model_management.maximum_batch_area()
if math.isclose(cond_scale, 1.0):
uncond = None
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
return model_options["sampler_cfg_function"](args)
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
return x - model_options["sampler_cfg_function"](args)
else:
return uncond + (cond - uncond) * cond_scale
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond, **kwargs)
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
return out
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model):
@ -332,32 +283,40 @@ class KSamplerX0Inpaint(torch.nn.Module):
return out
def simple_scheduler(model, steps):
s = model.model_sampling
sigs = []
ss = len(model.sigmas) / steps
ss = len(s.sigmas) / steps
for x in range(steps):
sigs += [float(model.sigmas[-(1 + int(x * ss))])]
sigs += [float(s.sigmas[-(1 + int(x * ss))])]
sigs += [0.0]
return torch.FloatTensor(sigs)
def ddim_scheduler(model, steps):
s = model.model_sampling
sigs = []
ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
for x in range(len(ddim_timesteps) - 1, -1, -1):
ts = ddim_timesteps[x]
if ts > 999:
ts = 999
sigs.append(model.t_to_sigma(torch.tensor(ts)))
ss = len(s.sigmas) // steps
x = 1
while x < len(s.sigmas):
sigs += [float(s.sigmas[x])]
x += ss
sigs = sigs[::-1]
sigs += [0.0]
return torch.FloatTensor(sigs)
def sgm_scheduler(model, steps):
def normal_scheduler(model, steps, sgm=False, floor=False):
s = model.model_sampling
start = s.timestep(s.sigma_max)
end = s.timestep(s.sigma_min)
if sgm:
timesteps = torch.linspace(start, end, steps + 1)[:-1]
else:
timesteps = torch.linspace(start, end, steps)
sigs = []
timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
for x in range(len(timesteps)):
ts = timesteps[x]
if ts > 999:
ts = 999
sigs.append(model.t_to_sigma(torch.tensor(ts)))
sigs.append(s.sigma(ts))
sigs += [0.0]
return torch.FloatTensor(sigs)
@ -389,19 +348,19 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
for i in range(len(conditions)):
c = conditions[i]
if 'area' in c[1]:
area = c[1]['area']
if 'area' in c:
area = c['area']
if area[0] == "percentage":
modified = c[1].copy()
modified = c.copy()
area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
modified['area'] = area
c = [c[0], modified]
c = modified
conditions[i] = c
if 'mask' in c[1]:
mask = c[1]['mask']
if 'mask' in c:
mask = c['mask']
mask = mask.to(device=device)
modified = c[1].copy()
modified = c.copy()
if len(mask.shape) == 2:
mask = mask.unsqueeze(0)
if mask.shape[1] != h or mask.shape[2] != w:
@ -422,66 +381,70 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
modified['area'] = area
modified['mask'] = mask
conditions[i] = [c[0], modified]
conditions[i] = modified
def create_cond_with_same_area_if_none(conds, c):
if 'area' not in c[1]:
if 'area' not in c:
return
c_area = c[1]['area']
c_area = c['area']
smallest = None
for x in conds:
if 'area' in x[1]:
a = x[1]['area']
if 'area' in x:
a = x['area']
if c_area[2] >= a[2] and c_area[3] >= a[3]:
if a[0] + a[2] >= c_area[0] + c_area[2]:
if a[1] + a[3] >= c_area[1] + c_area[3]:
if smallest is None:
smallest = x
elif 'area' not in smallest[1]:
elif 'area' not in smallest:
smallest = x
else:
if smallest[1]['area'][0] * smallest[1]['area'][1] > a[0] * a[1]:
if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
smallest = x
else:
if smallest is None:
smallest = x
if smallest is None:
return
if 'area' in smallest[1]:
if smallest[1]['area'] == c_area:
if 'area' in smallest:
if smallest['area'] == c_area:
return
n = c[1].copy()
conds += [[smallest[0], n]]
out = c.copy()
out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied?
conds += [out]
def calculate_start_end_timesteps(model, conds):
s = model.model_sampling
for t in range(len(conds)):
x = conds[t]
timestep_start = None
timestep_end = None
if 'start_percent' in x[1]:
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
if 'end_percent' in x[1]:
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
if 'start_percent' in x:
timestep_start = s.percent_to_sigma(x['start_percent'])
if 'end_percent' in x:
timestep_end = s.percent_to_sigma(x['end_percent'])
if (timestep_start is not None) or (timestep_end is not None):
n = x[1].copy()
n = x.copy()
if (timestep_start is not None):
n['timestep_start'] = timestep_start
if (timestep_end is not None):
n['timestep_end'] = timestep_end
conds[t] = [x[0], n]
conds[t] = n
def pre_run_control(model, conds):
s = model.model_sampling
for t in range(len(conds)):
x = conds[t]
timestep_start = None
timestep_end = None
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
if 'control' in x[1]:
x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@ -490,16 +453,16 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
uncond_other = []
for t in range(len(conds)):
x = conds[t]
if 'area' not in x[1]:
if name in x[1] and x[1][name] is not None:
cond_cnets.append(x[1][name])
if 'area' not in x:
if name in x and x[name] is not None:
cond_cnets.append(x[name])
else:
cond_other.append((x, t))
for t in range(len(uncond)):
x = uncond[t]
if 'area' not in x[1]:
if name in x[1] and x[1][name] is not None:
uncond_cnets.append(x[1][name])
if 'area' not in x:
if name in x and x[name] is not None:
uncond_cnets.append(x[name])
else:
uncond_other.append((x, t))
@ -509,47 +472,35 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
for x in range(len(cond_cnets)):
temp = uncond_other[x % len(uncond_other)]
o = temp[0]
if name in o[1] and o[1][name] is not None:
n = o[1].copy()
if name in o and o[name] is not None:
n = o.copy()
n[name] = uncond_fill_func(cond_cnets, x)
uncond += [[o[0], n]]
uncond += [n]
else:
n = o[1].copy()
n = o.copy()
n[name] = uncond_fill_func(cond_cnets, x)
uncond[temp[1]] = [o[0], n]
uncond[temp[1]] = n
def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs):
for t in range(len(conds)):
x = conds[t]
adm_out = None
if 'adm' in x[1]:
adm_out = x[1]["adm"]
else:
params = x[1].copy()
params["width"] = params.get("width", width * 8)
params["height"] = params.get("height", height * 8)
params["prompt_type"] = params.get("prompt_type", prompt_type)
adm_out = model.encode_adm(device=device, **params)
if adm_out is not None:
x[1] = x[1].copy()
x[1]["adm_encoded"] = comfy.utils.repeat_to_batch_size(adm_out, batch_size).to(device)
return conds
def encode_cond(model_function, key, conds, device, **kwargs):
for t in range(len(conds)):
x = conds[t]
params = x[1].copy()
params = x.copy()
params["device"] = device
params["noise"] = noise
params["width"] = params.get("width", noise.shape[3] * 8)
params["height"] = params.get("height", noise.shape[2] * 8)
params["prompt_type"] = params.get("prompt_type", prompt_type)
for k in kwargs:
if k not in params:
params[k] = kwargs[k]
out = model_function(**params)
if out is not None:
x[1] = x[1].copy()
x[1][key] = out
x = x.copy()
model_conds = x['model_conds'].copy()
for k in out:
model_conds[k] = out[k]
x['model_conds'] = model_conds
conds[t] = x
return conds
class Sampler:
@ -557,95 +508,79 @@ class Sampler:
pass
def max_denoise(self, model_wrap, sigmas):
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05)
class DDIM(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
max_denoise = self.max_denoise(model_wrap, sigmas)
ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise)
samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=model_wrap.predict_eps_discrete_timestep,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
return samples
max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
class UNIPC(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
class UNIPCBH2(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
def ksampler(sampler_name, extra_options={}):
class KSAMPLER(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap)
model_k.latent_image = latent_image
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
self.sampler_function = sampler_function
self.extra_options = extra_options
self.inpaint_options = inpaint_options
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap)
model_k.latent_image = latent_image
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
else:
model_k.noise = noise
if self.max_denoise(model_wrap, sigmas):
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
noise = noise * sigmas[0]
if self.max_denoise(model_wrap, sigmas):
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
noise = noise * sigmas[0]
k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
if latent_image is not None:
noise += latent_image
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
return samples
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
if sampler_name == "dpm_fast":
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
total_steps = len(sigmas) - 1
return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
sampler_function = dpm_fast_function
elif sampler_name == "dpm_adaptive":
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
sampler_function = dpm_adaptive_function
else:
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
if latent_image is not None:
noise += latent_image
if sampler_name == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
elif sampler_name == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options)
return samples
return KSAMPLER
return KSAMPLER(sampler_function, extra_options, inpaint_options)
def wrap_model(model):
model_denoise = CFGNoisePredictor(model)
if model.model_type == model_base.ModelType.V_PREDICTION:
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
else:
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
return model_wrap
return model_denoise
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
positive = positive[:]
@ -656,8 +591,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
model_wrap = wrap_model(model)
calculate_start_end_timesteps(model_wrap, negative)
calculate_start_end_timesteps(model_wrap, positive)
calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive)
#make sure each cond area has an opposite one with the same area
for c in positive:
@ -665,21 +600,17 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
for c in negative:
create_cond_with_same_area_if_none(positive, c)
pre_run_control(model_wrap, negative + positive)
pre_run_control(model, negative + positive)
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
if model.is_adm():
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
if hasattr(model, 'cond_concat'):
positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
@ -690,30 +621,29 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas_scheduler(model, scheduler_name, steps):
model_wrap = wrap_model(model)
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "normal":
sigmas = model_wrap.get_sigmas(steps)
sigmas = normal_scheduler(model, steps)
elif scheduler_name == "simple":
sigmas = simple_scheduler(model_wrap, steps)
sigmas = simple_scheduler(model, steps)
elif scheduler_name == "ddim_uniform":
sigmas = ddim_scheduler(model_wrap, steps)
sigmas = ddim_scheduler(model, steps)
elif scheduler_name == "sgm_uniform":
sigmas = sgm_scheduler(model_wrap, steps)
sigmas = normal_scheduler(model, steps, sgm=True)
else:
print("error invalid scheduler", self.scheduler)
return sigmas
def sampler_class(name):
def sampler_object(name):
if name == "uni_pc":
sampler = UNIPC
sampler = UNIPC()
elif name == "uni_pc_bh2":
sampler = UNIPCBH2
sampler = UNIPCBH2()
elif name == "ddim":
sampler = DDIM
sampler = ksampler("euler", inpaint_options={"random": True})
else:
sampler = ksampler(name)
return sampler
@ -776,6 +706,6 @@ class KSampler:
else:
return torch.zeros_like(noise)
sampler = sampler_class(self.sampler)
sampler = sampler_object(self.sampler)
return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)

View File

@ -23,6 +23,7 @@ import comfy.model_patcher
import comfy.lora
import comfy.t2i_adapter.adapter
import comfy.supported_models_base
import comfy.taesd.taesd
def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False)
@ -55,13 +56,26 @@ def load_clip_weights(model, sd):
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = comfy.lora.model_lora_keys_unet(model.model)
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
loaded = comfy.lora.load_lora(lora, key_map)
new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model)
new_clip = clip.clone()
k1 = new_clip.add_patches(loaded, strength_clip)
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model)
else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.add_patches(loaded, strength_clip)
else:
k1 = ()
new_clip = None
k = set(k)
k1 = set(k1)
for x in loaded:
@ -82,10 +96,7 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = offload_device
if model_management.should_use_fp16(load_device, prioritize_performance=False):
params['dtype'] = torch.float16
else:
params['dtype'] = torch.float32
params['dtype'] = model_management.text_encoder_dtype(load_device)
self.cond_stage_model = clip(**(params))
@ -144,10 +155,24 @@ class VAE:
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
if config is None:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
if "decoder.mid.block_1.mix_factor" in sd:
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
decoder_config = encoder_config.copy()
decoder_config["video_kernel_size"] = [3, 1, 1]
decoder_config["alpha"] = 0.0
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.first_stage_model = comfy.taesd.taesd.TAESD()
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
@ -162,10 +187,12 @@ class VAE:
if device is None:
device = model_management.vae_device()
self.device = device
self.offload_device = model_management.vae_offload_device()
offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype()
self.first_stage_model.to(self.vae_dtype)
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@ -194,10 +221,9 @@ class VAE:
return samples
def decode(self, samples_in):
self.first_stage_model = self.first_stage_model.to(self.device)
try:
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
model_management.free_memory(memory_used, self.device)
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@ -210,22 +236,19 @@ class VAE:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
self.first_stage_model = self.first_stage_model.to(self.device)
model_management.load_model_gpu(self.patcher)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return output.movedim(1,-1)
def encode(self, pixel_samples):
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
try:
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
model_management.free_memory(memory_used, self.device)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@ -238,14 +261,12 @@ class VAE:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
self.first_stage_model = self.first_stage_model.to(self.device)
model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples
def get_sd(self):
@ -360,7 +381,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
from . import latent_formats
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
model_config.unet_config = unet_config
model_config.unet_config = model_detection.convert_config(unet_config)
if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
@ -388,11 +409,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model.clip_h
elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model.clip_l
load_clip_weights(w, state_dict)
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
@ -429,6 +452,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd)
if output_clip:
@ -453,20 +477,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
return (model_patcher, clip, vae, clipvision)
def load_unet(unet_path): #load unet in diffusers format
sd = comfy.utils.load_torch_file(unet_path)
def load_unet_state_dict(sd): #load unet in diffusers format
parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return None
new_sd = sd
else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
if model_config is None:
print("ERROR UNSUPPORTED UNET", unet_path)
return None
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
@ -481,8 +503,19 @@ def load_unet(unet_path): #load unet in diffusers format
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
print("left over keys in unet:", left_over)
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def load_unet(unet_path):
sd = comfy.utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd)
if model is None:
print("ERROR UNSUPPORTED UNET", unet_path)
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def save_checkpoint(output_path, model, clip, vae, metadata=None):
model_management.load_models_gpu([model, clip.load_model()])
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())

View File

@ -8,34 +8,56 @@ import zipfile
from . import model_management
import contextlib
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
end_token = special_tokens.get("end", None)
pad_token = special_tokens.get("pad")
output = []
if start_token is not None:
output.append(start_token)
if end_token is not None:
output.append(end_token)
output += [pad_token] * (length - len(output))
return output
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
to_encode = list(self.empty_tokens)
to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode)
z_empty = out[0:1]
if pooled.shape[0] > 1:
first_pooled = pooled[1:2]
if pooled is not None:
first_pooled = pooled[0:1].cpu()
else:
first_pooled = pooled[0:1]
first_pooled = pooled
output = []
for k in range(1, out.shape[0]):
for k in range(0, sections):
z = out[k:k+1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k - 1][j][1]
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
if has_weights:
z_empty = out[-1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k][j][1]
if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
output.append(z)
if (len(output) == 0):
return z_empty.cpu(), first_pooled.cpu()
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
return out[-1:].cpu(), first_pooled
return torch.cat(output, dim=-2).cpu(), first_pooled
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
@ -43,37 +65,43 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.num_layers = 12
if textmodel_path is not None:
self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
self.transformer = model_class.from_pretrained(textmodel_path)
else:
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
config = CLIPTextConfig.from_json_file(textmodel_json_config)
config = config_class.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
with comfy.ops.use_comfy_ops(device, dtype):
with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config)
self.transformer = model_class(config)
self.inner_name = inner_name
if dtype is not None:
self.transformer.to(dtype)
self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
inner_model = getattr(self.transformer, self.inner_name)
if hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32)
else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = None
self.empty_tokens = [[49406] + [49407] * 76]
self.special_tokens = special_tokens
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.enable_attention_masks = False
self.layer_norm_hidden_state = True
self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) <= self.num_layers
@ -117,7 +145,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else:
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
while len(tokens_temp) < len(x):
tokens_temp += [self.empty_tokens[0][-1]]
tokens_temp += [self.special_tokens["pad"]]
out_tokens += [tokens_temp]
n = token_dict_size
@ -142,12 +170,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
with precision_scope(model_management.get_autocast_device(device), torch.float32):
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
attention_mask = None
if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
@ -168,12 +196,16 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else:
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = self.transformer.text_model.final_layer_norm(z)
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
pooled_output = outputs.pooler_output
if self.text_projection is not None:
if hasattr(outputs, "pooler_output"):
pooled_output = outputs.pooler_output.float()
else:
pooled_output = None
if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output.float()
return z.float(), pooled_output
def encode(self, tokens):
return self(tokens)
@ -278,7 +310,13 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
valid_file = None
for embed_dir in embedding_directory:
embed_path = os.path.join(embed_dir, embedding_name)
embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
embed_dir = os.path.abspath(embed_dir)
try:
if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
continue
except:
continue
if not os.path.isfile(embed_path):
extensions = ['.safetensors', '.pt', '.bin']
for x in extensions:
@ -336,18 +374,25 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
embed_out = next(iter(values))
return embed_out
class SD1Tokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length
self.max_tokens_per_section = self.max_length - 2
empty = self.tokenizer('')["input_ids"]
self.start_token = empty[0]
self.end_token = empty[1]
if has_start_token:
self.tokens_start = 1
self.start_token = empty[0]
self.end_token = empty[1]
else:
self.tokens_start = 0
self.start_token = None
self.end_token = empty[0]
self.pad_with_end = pad_with_end
self.pad_to_max_length = pad_to_max_length
vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
self.embedding_directory = embedding_directory
@ -408,11 +453,13 @@ class SD1Tokenizer:
else:
continue
#parse word
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
#reshape token array to CLIP input size
batched_tokens = []
batch = [(self.start_token, 1.0, 0)]
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
for i, t_group in enumerate(tokens):
#determine if we're going to try and keep the tokens in a single batch
@ -429,16 +476,21 @@ class SD1Tokenizer:
#add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
#start new batch
batch = [(self.start_token, 1.0, 0)]
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
else:
batch.extend([(t,w,i+1) for t,w in t_group])
t_group = []
#fill last batch
batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
@ -448,3 +500,40 @@ class SD1Tokenizer:
def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
class SD1Tokenizer:
def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return getattr(self, self.clip).untokenize(token_weight_pair)
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs):
super().__init__()
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
def clip_layer(self, layer_idx):
getattr(self, self.clip).clip_layer(layer_idx)
def reset_clip_layer(self):
getattr(self, self.clip).reset_clip_layer()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs[self.clip_name]
out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
return out, pooled
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)

View File

@ -2,16 +2,23 @@ from comfy import sd1_clip
import torch
import os
class SD2ClipModel(sd1_clip.SD1ClipModel):
class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=23
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)

View File

@ -2,28 +2,27 @@ from comfy import sd1_clip
import torch
import os
class SDXLClipG(sd1_clip.SD1ClipModel):
class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
self.layer_norm_hidden_state = False
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
def load_sd(self, sd):
return super().load_sd(sd)
class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
class SDXLTokenizer(sd1_clip.SD1Tokenizer):
class SDXLTokenizer:
def __init__(self, embedding_directory=None):
self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory)
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
@ -38,8 +37,7 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
self.clip_l.layer_norm_hidden_state = False
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx):
@ -63,21 +61,6 @@ class SDXLClipModel(torch.nn.Module):
else:
return self.clip_l.load_sd(sd)
class SDXLRefinerClipModel(torch.nn.Module):
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx):
self.clip_g.clip_layer(layer_idx)
def reset_clip_layer(self):
self.clip_g.reset_clip_layer()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_g = token_weight_pairs["g"]
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
return g_out, g_pooled
def load_sd(self, sd):
return self.clip_g.load_sd(sd)
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)

View File

@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
}
unet_extra_config = {
@ -38,8 +39,15 @@ class SD15(supported_models_base.BASE):
if ids.dtype == torch.float32:
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
replace_prefix = {}
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"clip_l.": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def clip_target(self):
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
@ -49,6 +57,7 @@ class SD20(supported_models_base.BASE):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": None,
"use_temporal_attention": False,
}
latent_format = latent_formats.SD15
@ -62,12 +71,12 @@ class SD20(supported_models_base.BASE):
return model_base.ModelType.EPS
def process_clip_state_dict(self, state_dict):
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
replace_prefix[""] = "cond_stage_model.model."
replace_prefix["clip_h"] = "cond_stage_model.model"
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict
@ -81,6 +90,7 @@ class SD21UnclipL(SD20):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 1536,
"use_temporal_attention": False,
}
clip_vision_prefix = "embedder.model.visual."
@ -93,6 +103,7 @@ class SD21UnclipH(SD20):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 2048,
"use_temporal_attention": False,
}
clip_vision_prefix = "embedder.model.visual."
@ -104,7 +115,8 @@ class SDXLRefiner(supported_models_base.BASE):
"use_linear_in_transformer": True,
"context_dim": 1280,
"adm_in_channels": 2560,
"transformer_depth": [0, 4, 4, 0],
"transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
"use_temporal_attention": False,
}
latent_format = latent_formats.SDXL
@ -139,9 +151,10 @@ class SDXL(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 2, 10],
"transformer_depth": [0, 0, 2, 2, 10, 10],
"context_dim": 2048,
"adm_in_channels": 2816
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
latent_format = latent_formats.SDXL
@ -165,6 +178,7 @@ class SDXL(supported_models_base.BASE):
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
@ -189,5 +203,40 @@ class SDXL(supported_models_base.BASE):
def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
class SSD1B(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 4, 4],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL]
class SVD_img2vid(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 768,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
latent_format = latent_formats.SD15
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SVD_img2vid(self, device=device)
return out
def clip_target(self):
return None
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
models += [SVD_img2vid]

View File

@ -19,7 +19,7 @@ class BASE:
clip_prefix = []
clip_vision_prefix = None
noise_aug_config = None
beta_schedule = "linear"
sampling_settings = {}
latent_format = latent_formats.LatentFormat
@classmethod
@ -53,6 +53,12 @@ class BASE:
def process_clip_state_dict(self, state_dict):
return state_dict
def process_unet_state_dict(self, state_dict):
return state_dict
def process_vae_state_dict(self, state_dict):
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)

View File

@ -46,15 +46,16 @@ class TAESD(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
def __init__(self, encoder_path=None, decoder_path=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
self.taesd_encoder = Encoder()
self.taesd_decoder = Decoder()
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
if encoder_path is not None:
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod
def scale_latents(x):
@ -65,3 +66,11 @@ class TAESD(nn.Module):
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode(self, x):
x_sample = self.taesd_decoder(x * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample
def encode(self, x):
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale

View File

@ -170,25 +170,12 @@ UNET_MAP_BASIC = {
def unet_to_diffusers(unet_config):
num_res_blocks = unet_config["num_res_blocks"]
attention_resolutions = unet_config["attention_resolutions"]
channel_mult = unet_config["channel_mult"]
transformer_depth = unet_config["transformer_depth"]
transformer_depth = unet_config["transformer_depth"][:]
transformer_depth_output = unet_config["transformer_depth_output"][:]
num_blocks = len(channel_mult)
if isinstance(num_res_blocks, int):
num_res_blocks = [num_res_blocks] * num_blocks
if isinstance(transformer_depth, int):
transformer_depth = [transformer_depth] * num_blocks
transformers_per_layer = []
res = 1
for i in range(num_blocks):
transformers = 0
if res in attention_resolutions:
transformers = transformer_depth[i]
transformers_per_layer.append(transformers)
res *= 2
transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])
transformers_mid = unet_config.get("transformer_depth_middle", None)
diffusers_unet_map = {}
for x in range(num_blocks):
@ -196,10 +183,11 @@ def unet_to_diffusers(unet_config):
for i in range(num_res_blocks[x]):
for b in UNET_MAP_RESNET:
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
if transformers_per_layer[x] > 0:
num_transformers = transformer_depth.pop(0)
if num_transformers > 0:
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
for t in range(transformers_per_layer[x]):
for t in range(num_transformers):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
n += 1
@ -218,7 +206,6 @@ def unet_to_diffusers(unet_config):
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
num_res_blocks = list(reversed(num_res_blocks))
transformers_per_layer = list(reversed(transformers_per_layer))
for x in range(num_blocks):
n = (num_res_blocks[x] + 1) * x
l = num_res_blocks[x] + 1
@ -227,11 +214,12 @@ def unet_to_diffusers(unet_config):
for b in UNET_MAP_RESNET:
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
c += 1
if transformers_per_layer[x] > 0:
num_transformers = transformer_depth_output.pop()
if num_transformers > 0:
c += 1
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
for t in range(transformers_per_layer[x]):
for t in range(num_transformers):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
if i == l - 1:
@ -270,9 +258,17 @@ def set_attr(obj, attr, value):
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value))
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
del prev
def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
prev.data.copy_(value)
def get_attr(obj, attr):
attrs = attr.split(".")
for name in attrs:
@ -311,23 +307,25 @@ def bislerp(samples, width, height):
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res
def generate_bilinear_data(length_old, length_new):
coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
def generate_bilinear_data(length_old, length_new, device):
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
ratios = coords_1 - coords_1.floor()
coords_1 = coords_1.to(torch.int64)
coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
coords_2[:,:,:,-1] -= 1
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
coords_2 = coords_2.to(torch.int64)
return ratios, coords_1, coords_2
orig_dtype = samples.dtype
samples = samples.float()
n,c,h,w = samples.shape
h_new, w_new = (height, width)
#linear w
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
coords_1 = coords_1.expand((n, c, h, -1))
coords_2 = coords_2.expand((n, c, h, -1))
ratios = ratios.expand((n, 1, h, -1))
@ -340,7 +338,7 @@ def bislerp(samples, width, height):
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
#linear h
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
@ -351,7 +349,7 @@ def bislerp(samples, width, height):
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result
return result.to(orig_dtype)
def lanczos(samples, width, height):
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]

View File

@ -16,7 +16,7 @@ class BasicScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -36,7 +36,7 @@ class KarrasScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -54,7 +54,7 @@ class ExponentialScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -73,7 +73,7 @@ class PolyexponentialScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -81,6 +81,25 @@ class PolyexponentialScheduler:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, )
class SDTurboScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps):
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
sigmas = model.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
class VPScheduler:
@classmethod
def INPUT_TYPES(s):
@ -92,7 +111,7 @@ class VPScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -109,7 +128,7 @@ class SplitSigmas:
}
}
RETURN_TYPES = ("SIGMAS","SIGMAS")
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "get_sigmas"
@ -118,6 +137,24 @@ class SplitSigmas:
sigmas2 = sigmas[step:]
return (sigmas1, sigmas2)
class FlipSigmas:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sigmas": ("SIGMAS", ),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "get_sigmas"
def get_sigmas(self, sigmas):
sigmas = sigmas.flip(0)
if sigmas[0] == 0:
sigmas[0] = 0.0001
return (sigmas,)
class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
@ -126,12 +163,12 @@ class KSamplerSelect:
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, sampler_name):
sampler = comfy.samplers.sampler_class(sampler_name)()
sampler = comfy.samplers.sampler_object(sampler_name)
return (sampler, )
class SamplerDPMPP_2M_SDE:
@ -145,7 +182,7 @@ class SamplerDPMPP_2M_SDE:
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
@ -154,7 +191,7 @@ class SamplerDPMPP_2M_SDE:
sampler_name = "dpmpp_2m_sde"
else:
sampler_name = "dpmpp_2m_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})()
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
return (sampler, )
@ -169,7 +206,7 @@ class SamplerDPMPP_SDE:
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
@ -178,7 +215,7 @@ class SamplerDPMPP_SDE:
sampler_name = "dpmpp_sde"
else:
sampler_name = "dpmpp_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})()
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
return (sampler, )
class SamplerCustom:
@ -188,7 +225,7 @@ class SamplerCustom:
{"model": ("MODEL",),
"add_noise": ("BOOLEAN", {"default": True}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"sampler": ("SAMPLER", ),
@ -234,13 +271,15 @@ class SamplerCustom:
NODE_CLASS_MAPPINGS = {
"SamplerCustom": SamplerCustom,
"BasicScheduler": BasicScheduler,
"KarrasScheduler": KarrasScheduler,
"ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler,
"VPScheduler": VPScheduler,
"SDTurboScheduler": SDTurboScheduler,
"KSamplerSelect": KSamplerSelect,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"BasicScheduler": BasicScheduler,
"SplitSigmas": SplitSigmas,
"FlipSigmas": FlipSigmas,
}

View File

@ -0,0 +1,175 @@
import nodes
import folder_paths
from comfy.cli_args import args
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import json
import os
MAX_RESOLUTION = nodes.MAX_RESOLUTION
class ImageCrop:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "crop"
CATEGORY = "image/transform"
def crop(self, image, width, height, x, y):
x = min(x, image.shape[2] - 1)
y = min(y, image.shape[1] - 1)
to_x = width + x
to_y = height + y
img = image[:,y:to_y, x:to_x, :]
return (img,)
class RepeatImageBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
CATEGORY = "image/batch"
def repeat(self, image, amount):
s = image.repeat((amount, 1,1,1))
return (s,)
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
methods = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"lossless": ("BOOLEAN", {"default": True}),
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
"method": (list(s.methods.keys()),),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method)
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = pil_images[0].getexif()
if not args.disable_metadata:
if prompt is not None:
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
if extra_pnginfo is not None:
inital_exif = 0x010f
for x in extra_pnginfo:
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
inital_exif -= 1
if num_frames == 0:
num_frames = len(pil_images)
c = len(pil_images)
for i in range(0, c, num_frames):
file = f"{filename}_{counter:05}_.webp"
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
animated = num_frames != 1
return { "ui": { "images": results, "animated": (animated,) } }
class SaveAnimatedPNG:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
file = f"{filename}_{counter:05}_.png"
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
return { "ui": { "images": results, "animated": (True,)} }
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
}

View File

@ -1,4 +1,5 @@
import comfy.utils
import torch
def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]:
@ -67,8 +68,43 @@ class LatentMultiply:
samples_out["samples"] = s1 * multiplier
return (samples_out,)
class LatentInterpolate:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, ratio):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
s2 = reshape_latent_to(s1.shape, s2)
m1 = torch.linalg.vector_norm(s1, dim=(1))
m2 = torch.linalg.vector_norm(s2, dim=(1))
s1 = torch.nan_to_num(s1 / m1)
s2 = torch.nan_to_num(s2 / m2)
t = (s1 * ratio + s2 * (1.0 - ratio))
mt = torch.linalg.vector_norm(t, dim=(1))
st = torch.nan_to_num(t / mt)
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
}

View File

@ -0,0 +1,205 @@
import folder_paths
import comfy.sd
import comfy.model_sampling
import torch
class LCM(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
x0 = model_input - model_output * sigma
sigma_data = 0.5
scaled_timestep = timestep * 10.0 #timestep_scaling
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteDistilled(torch.nn.Module):
original_timesteps = 50
def __init__(self):
super().__init__()
self.sigma_data = 1.0
timesteps = 1000
beta_start = 0.00085
beta_end = 0.012
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
self.skip_steps = timesteps // self.original_timesteps
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
for x in range(self.original_timesteps):
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
def sigma(self, timestep):
t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, sampling, zsnr):
m = model.clone()
sampling_base = comfy.model_sampling.ModelSamplingDiscrete
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
elif sampling == "lcm":
sampling_type = LCM
sampling_base = ModelSamplingDiscreteDistilled
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "eps"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
model_sampling.set_sigma_range(sigma_min, sigma_max)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class RescaleCFG:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, multiplier):
def rescale_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
sigma = args["sigma"]
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
x_orig = args["input"]
#rescale cfg has to be done on v-pred model output
x = x_orig / (sigma * sigma + 1.0)
cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
#rescalecfg
x_cfg = uncond + cond_scale * (cond - uncond)
ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True)
ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True)
x_rescaled = x_cfg * (ro_pos / ro_cfg)
x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg
return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)
m = model.clone()
m.set_model_sampler_cfg_function(rescale_cfg)
return (m, )
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
"RescaleCFG": RescaleCFG,
}

View File

@ -0,0 +1,53 @@
import torch
import comfy.utils
class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
"downscale_after_skip": ("BOOLEAN", {"default": True}),
"downscale_method": (s.upscale_methods,),
"upscale_method": (s.upscale_methods,),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
def input_block_patch(h, transformer_options):
if transformer_options["block"][1] == block_number:
sigma = transformer_options["sigmas"][0].item()
if sigma <= sigma_start and sigma >= sigma_end:
h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
return h
def output_block_patch(h, hsp, transformer_options):
if h.shape[2] != hsp.shape[2]:
h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
return h, hsp
m = model.clone()
if downscale_after_skip:
m.set_model_input_block_patch_after_skip(input_block_patch)
else:
m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch)
return (m, )
NODE_CLASS_MAPPINGS = {
"PatchModelAddDownscale": PatchModelAddDownscale,
}
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
}

View File

@ -23,7 +23,7 @@ class Blend:
"max": 1.0,
"step": 0.01
}),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
},
}
@ -54,6 +54,8 @@ class Blend:
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
elif mode == "soft_light":
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
elif mode == "difference":
return img1 - img2
else:
raise ValueError(f"Unsupported blend mode: {mode}")
@ -126,7 +128,7 @@ class Quantize:
"max": 256,
"step": 1
}),
"dither": (["none", "floyd-steinberg"],),
"dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
},
}
@ -135,19 +137,47 @@ class Quantize:
CATEGORY = "image/postprocessing"
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
def bayer(im, pal_im, order):
def normalized_bayer_matrix(n):
if n == 0:
return np.zeros((1,1), "float32")
else:
q = 4 ** n
m = q * normalized_bayer_matrix(n - 1)
return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q
num_colors = len(pal_im.getpalette()) // 3
spread = 2 * 256 / num_colors
bayer_n = int(math.log2(order))
bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)
result = torch.from_numpy(np.array(im).astype(np.float32))
tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255)
result = result.to(dtype=torch.uint8)
im = Image.fromarray(result.cpu().numpy())
im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
return im
def quantize(self, image: torch.Tensor, colors: int, dither: str):
batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
for b in range(batch_size):
tensor_image = image[b]
img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB')
im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
if dither == "none":
quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
elif dither == "floyd-steinberg":
quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
elif dither.startswith("bayer"):
order = int(dither.split('-')[-1])
quantized_image = Quantize.bayer(im, pal_im, order)
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
result[b] = quantized_array

View File

@ -4,7 +4,7 @@ class LatentRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "latents": ("LATENT",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("LATENT",)
INPUT_IS_LIST = True

View File

@ -0,0 +1,89 @@
import nodes
import torch
import comfy.utils
import comfy.sd
import folder_paths
class ImageOnlyCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])
class SVD_img2vid_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
if augmentation_level > 0:
encode_pixels += torch.randn_like(pixels) * augmentation_level
t = vae.encode(encode_pixels)
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
class VideoLinearCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1))
return uncond + scale * (cond - uncond)
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
}

View File

@ -683,6 +683,7 @@ def validate_prompt(prompt):
return (True, None, list(good_outputs), node_errors)
MAXIMUM_HISTORY_SIZE = 10000
class PromptQueue:
def __init__(self, server):
@ -715,6 +716,8 @@ class PromptQueue:
def task_done(self, item_id, outputs):
with self.mutex:
prompt = self.currently_running.pop(item_id)
if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history)))
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
for o in outputs:
self.history[prompt[1]]["outputs"][o] = outputs[o]
@ -749,10 +752,20 @@ class PromptQueue:
return True
return False
def get_history(self, prompt_id=None):
def get_history(self, prompt_id=None, max_items=None, offset=-1):
with self.mutex:
if prompt_id is None:
return copy.deepcopy(self.history)
out = {}
i = 0
if offset < 0 and max_items is not None:
offset = len(self.history) - max_items
for k in self.history:
if i >= offset:
out[k] = self.history[k]
if max_items is not None and len(out) >= max_items:
break
i += 1
return out
elif prompt_id in self.history:
return {prompt_id: copy.deepcopy(self.history[prompt_id])}
else:

View File

@ -38,7 +38,10 @@ input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "inp
filename_list_cache = {}
if not os.path.exists(input_directory):
os.makedirs(input_directory)
try:
os.makedirs(input_directory)
except:
print("Failed to create input directory")
def set_output_directory(output_dir):
global output_directory
@ -227,8 +230,12 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
full_output_folder = os.path.join(output_dir, subfolder)
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
print("Saving image outside the output folder is not allowed.")
return {}
err = "**** ERROR: Saving image outside the output folder is not allowed." + \
"\n full_output_folder: " + os.path.abspath(full_output_folder) + \
"\n output_dir: " + output_dir + \
"\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
print(err)
raise Exception(err)
try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1

View File

@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
self.taesd = taesd
def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decoder(x0)[0].detach()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample = x_sample.sub(0.5).mul(2)
x_sample = self.taesd.decode(x0[:1])[0].detach()
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)

12
main.py
View File

@ -88,6 +88,7 @@ def cuda_malloc_warning():
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
last_gc_collect = 0
while True:
item, item_id = q.get()
execution_start_time = time.perf_counter()
@ -97,9 +98,14 @@ def prompt_worker(q, server):
if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
gc.collect()
comfy.model_management.soft_empty_cache()
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
print("Prompt executed in {:.2f} seconds".format(execution_time))
if (current_time - last_gc_collect) > 10.0:
gc.collect()
comfy.model_management.soft_empty_cache()
last_gc_collect = current_time
print("gc collect")
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())

View File

@ -248,8 +248,8 @@ class ConditioningSetTimestepRange:
c = []
for t in conditioning:
d = t[1].copy()
d['start_percent'] = 1.0 - start
d['end_percent'] = 1.0 - end
d['start_percent'] = start
d['end_percent'] = end
n = [t[0], d]
c.append(n)
return (c, )
@ -572,10 +572,69 @@ class LoraLoader:
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
return (model_lora, clip_lora)
class VAELoader:
class LoraLoaderModelOnly(LoraLoader):
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
return {"required": { "model": ("MODEL",),
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_lora_model_only"
def load_lora_model_only(self, model, lora_name, strength_model):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class VAELoader:
@staticmethod
def vae_list():
vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False
sdxl_taesd_dec = False
sd1_taesd_enc = False
sd1_taesd_dec = False
for v in approx_vaes:
if v.startswith("taesd_decoder."):
sd1_taesd_dec = True
elif v.startswith("taesd_encoder."):
sd1_taesd_enc = True
elif v.startswith("taesdxl_decoder."):
sdxl_taesd_dec = True
elif v.startswith("taesdxl_encoder."):
sdxl_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl")
return vaes
@staticmethod
def load_taesd(name):
sd = {}
approx_vaes = folder_paths.get_filename_list("vae_approx")
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]
if name == "taesd":
sd["vae_scale"] = torch.tensor(0.18215)
elif name == "taesdxl":
sd["vae_scale"] = torch.tensor(0.13025)
return sd
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (s.vae_list(), )}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@ -583,8 +642,11 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
if vae_name in ["taesd", "taesdxl"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (vae,)
@ -685,7 +747,7 @@ class ControlNetApplyAdvanced:
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
@ -1218,7 +1280,7 @@ class KSampler:
{"model": ("MODEL",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
"positive": ("CONDITIONING", ),
@ -1244,7 +1306,7 @@ class KSamplerAdvanced:
"add_noise": (["enable", "disable"], ),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
"positive": ("CONDITIONING", ),
@ -1275,6 +1337,7 @@ class SaveImage:
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
self.compress_level = 4
@classmethod
def INPUT_TYPES(s):
@ -1308,7 +1371,7 @@ class SaveImage:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
file = f"{filename}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
results.append({
"filename": file,
"subfolder": subfolder,
@ -1323,6 +1386,7 @@ class PreviewImage(SaveImage):
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
self.compress_level = 1
@classmethod
def INPUT_TYPES(s):
@ -1654,6 +1718,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningZeroOut": ConditioningZeroOut,
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
"LoraLoaderModelOnly": LoraLoaderModelOnly,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@ -1759,7 +1824,7 @@ def load_custom_nodes():
node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = []
for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path)
possible_modules = os.listdir(os.path.realpath(custom_node_path))
if "__pycache__" in possible_modules:
possible_modules.remove("__pycache__")
@ -1799,6 +1864,10 @@ def init_custom_nodes():
"nodes_custom_sampler.py",
"nodes_subflow.py",
"nodes_hypertile.py",
"nodes_model_advanced.py",
"nodes_model_downscale.py",
"nodes_images.py",
"nodes_video_model.py",
]
for node_file in extras_files:

View File

@ -83,7 +83,8 @@ class PromptServer():
if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header))
self.app = web.Application(client_max_size=104857600, middlewares=middlewares)
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
self.sockets = dict()
self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "web")
@ -433,7 +434,10 @@ class PromptServer():
@routes.get("/history")
async def get_history(request):
return web.json_response(self.prompt_queue.get_history())
max_items = request.rel_url.query.get("max_items", None)
if max_items is not None:
max_items = int(max_items)
return web.json_response(self.prompt_queue.get_history(max_items=max_items))
@routes.get("/history/{prompt_id}")
async def get_history(request):
@ -590,7 +594,7 @@ class PromptServer():
bytesIO = BytesIO()
header = struct.pack(">I", type_num)
bytesIO.write(header)
image.save(bytesIO, format=image_type, quality=95, compress_level=4)
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

View File

@ -14,10 +14,10 @@ const lg = require("../utils/litegraph");
* @param { InstanceType<Ez["EzGraph"]> } graph
* @param { InstanceType<Ez["EzInput"]> } input
* @param { string } widgetType
* @param { boolean } hasControlWidget
* @param { number } controlWidgetCount
* @returns
*/
async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasControlWidget) {
async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) {
// Connect to primitive and ensure its still connected after
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(input);
@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro
expect(valueWidget.widget.type).toBe(widgetType);
// Check if control_after_generate should be added
if (hasControlWidget) {
if (controlWidgetCount) {
const controlWidget = primitive.widgets.control_after_generate;
expect(controlWidget.widget.type).toBe("combo");
if(widgetType === "combo") {
const filterWidget = primitive.widgets.control_filter_list;
expect(filterWidget.widget.type).toBe("string");
}
}
// Ensure we dont have other widgets
expect(primitive.node.widgets).toHaveLength(1 + +!!hasControlWidget);
expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount);
});
return primitive;
@ -55,8 +59,8 @@ describe("widget inputs", () => {
});
[
{ name: "int", type: "INT", widget: "number", control: true },
{ name: "float", type: "FLOAT", widget: "number", control: true },
{ name: "int", type: "INT", widget: "number", control: 1 },
{ name: "float", type: "FLOAT", widget: "number", control: 1 },
{ name: "text", type: "STRING" },
{
name: "customtext",
@ -64,7 +68,7 @@ describe("widget inputs", () => {
opt: { multiline: true },
},
{ name: "toggle", type: "BOOLEAN" },
{ name: "combo", type: ["a", "b", "c"], control: true },
{ name: "combo", type: ["a", "b", "c"], control: 2 },
].forEach((c) => {
test(`widget conversion + primitive works on ${c.name}`, async () => {
const { ez, graph } = await start({
@ -106,7 +110,7 @@ describe("widget inputs", () => {
n.widgets.ckpt_name.convertToInput();
expect(n.inputs.length).toEqual(inputCount + 1);
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", true);
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2);
// Disconnect & reconnect
primitive.outputs[0].connections[0].disconnect();
@ -226,7 +230,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget
if (!assertNotNullOrUndefined(input)) return;
await connectPrimitiveAndReload(ez, graph, input, "number", true);
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
w = n.widgets.example;
@ -258,7 +262,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget
if (assertNotNullOrUndefined(input)) {
await connectPrimitiveAndReload(ez, graph, input, "number", true);
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
expect(n.widgets.example.isConvertedToInput).toBeTruthy();
@ -316,4 +320,76 @@ describe("widget inputs", () => {
n1.outputs[0].connectTo(n2.inputs[0]);
expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow();
});
test("combo primitive can filter list when control_after_generate called", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }),
},
});
const n1 = ez.TestNode1();
n1.widgets.example.convertToInput();
const p = ez.PrimitiveNode()
p.outputs[0].connectTo(n1.inputs[0]);
const value = p.widgets.value;
const control = p.widgets.control_after_generate.widget;
const filter = p.widgets.control_filter_list;
expect(p.widgets.length).toBe(3);
control.value = "increment";
expect(value.value).toBe("A");
// Manually trigger after queue when set to increment
control["afterQueued"]();
expect(value.value).toBe("B");
// Filter to items containing D
filter.value = "D";
control["afterQueued"]();
expect(value.value).toBe("D");
control["afterQueued"]();
expect(value.value).toBe("DD");
// Check decrement
value.value = "BBB";
control.value = "decrement";
filter.value = "B";
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("B");
// Check regex works
value.value = "BBB";
filter.value = "/[AB]|^C$/";
control["afterQueued"]();
expect(value.value).toBe("AAA");
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("AA");
control["afterQueued"]();
expect(value.value).toBe("C");
control["afterQueued"]();
expect(value.value).toBe("B");
control["afterQueued"]();
expect(value.value).toBe("A");
// Check random
control.value = "randomize";
filter.value = "/D/";
for(let i = 0; i < 100; i++) {
control["afterQueued"]();
expect(value.value === "D" || value.value === "DD").toBeTruthy();
}
// Ensure it doesnt apply when fixed
control.value = "fixed";
value.value = "B";
filter.value = "C";
control["afterQueued"]();
expect(value.value).toBe("B");
});
});

View File

@ -174,6 +174,213 @@ const colorPalettes = {
"tr-odd-bg-color": "#073642",
}
},
},
"arc": {
"id": "arc",
"name": "Arc",
"colors": {
"node_slot": {
"BOOLEAN": "",
"CLIP": "#eacb8b",
"CLIP_VISION": "#A8DADC",
"CLIP_VISION_OUTPUT": "#ad7452",
"CONDITIONING": "#cf876f",
"CONTROL_NET": "#00d78d",
"CONTROL_NET_WEIGHTS": "",
"FLOAT": "",
"GLIGEN": "",
"IMAGE": "#80a1c0",
"IMAGEUPLOAD": "",
"INT": "",
"LATENT": "#b38ead",
"LATENT_KEYFRAME": "",
"MASK": "#a3bd8d",
"MODEL": "#8978a7",
"SAMPLER": "",
"SIGMAS": "",
"STRING": "",
"STYLE_MODEL": "#C2FFAE",
"T2I_ADAPTER_WEIGHTS": "",
"TAESD": "#DCC274",
"TIMESTEP_KEYFRAME": "",
"UPSCALE_MODEL": "",
"VAE": "#be616b"
},
"litegraph_base": {
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAYAAABw4pVUAAAACXBIWXMAAAsTAAALEwEAmpwYAAABcklEQVR4nO3YMUoDARgF4RfxBqZI6/0vZqFn0MYtrLIQMFN8U6V4LAtD+Jm9XG/v30OGl2e/AP7yevz4+vx45nvgF/+QGITEICQGITEIiUFIjNNC3q43u3/YnRJyPOzeQ+0e220nhRzReC8e7R7bbdvl+Jal1Bs46jEIiUFIDEJiEBKDkBhKPbZT6qHdptRTu02p53DUYxASg5AYhMQgJAYhMZR6bKfUQ7tNqad2m1LP4ajHICQGITEIiUFIDEJiKPXYTqmHdptST+02pZ7DUY9BSAxCYhASg5AYhMRQ6rGdUg/tNqWe2m1KPYejHoOQGITEICQGITEIiaHUYzulHtptSj2125R6Dkc9BiExCIlBSAxCYhASQ6nHdko9tNuUemq3KfUcjnoMQmIQEoOQGITEICSGUo/tlHpotyn11G5T6jkc9RiExCAkBiExCIlBSAylHtsp9dBuU+qp3abUczjqMQiJQUgMQmIQEoOQGITE+AHFISNQrFTGuwAAAABJRU5ErkJggg==",
"CLEAR_BACKGROUND_COLOR": "#2b2f38",
"NODE_TITLE_COLOR": "#b2b7bd",
"NODE_SELECTED_TITLE_COLOR": "#FFF",
"NODE_TEXT_SIZE": 14,
"NODE_TEXT_COLOR": "#AAA",
"NODE_SUBTEXT_SIZE": 12,
"NODE_DEFAULT_COLOR": "#2b2f38",
"NODE_DEFAULT_BGCOLOR": "#242730",
"NODE_DEFAULT_BOXCOLOR": "#6e7581",
"NODE_DEFAULT_SHAPE": "box",
"NODE_BOX_OUTLINE_COLOR": "#FFF",
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
"DEFAULT_GROUP_FONT": 22,
"WIDGET_BGCOLOR": "#2b2f38",
"WIDGET_OUTLINE_COLOR": "#6e7581",
"WIDGET_TEXT_COLOR": "#DDD",
"WIDGET_SECONDARY_TEXT_COLOR": "#b2b7bd",
"LINK_COLOR": "#9A9",
"EVENT_LINK_COLOR": "#A86",
"CONNECTING_LINK_COLOR": "#AFA"
},
"comfy_base": {
"fg-color": "#fff",
"bg-color": "#2b2f38",
"comfy-menu-bg": "#242730",
"comfy-input-bg": "#2b2f38",
"input-text": "#ddd",
"descrip-text": "#b2b7bd",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#6e7581",
"tr-even-bg-color": "#2b2f38",
"tr-odd-bg-color": "#242730"
}
},
},
"nord": {
"id": "nord",
"name": "Nord",
"colors": {
"node_slot": {
"BOOLEAN": "",
"CLIP": "#eacb8b",
"CLIP_VISION": "#A8DADC",
"CLIP_VISION_OUTPUT": "#ad7452",
"CONDITIONING": "#cf876f",
"CONTROL_NET": "#00d78d",
"CONTROL_NET_WEIGHTS": "",
"FLOAT": "",
"GLIGEN": "",
"IMAGE": "#80a1c0",
"IMAGEUPLOAD": "",
"INT": "",
"LATENT": "#b38ead",
"LATENT_KEYFRAME": "",
"MASK": "#a3bd8d",
"MODEL": "#8978a7",
"SAMPLER": "",
"SIGMAS": "",
"STRING": "",
"STYLE_MODEL": "#C2FFAE",
"T2I_ADAPTER_WEIGHTS": "",
"TAESD": "#DCC274",
"TIMESTEP_KEYFRAME": "",
"UPSCALE_MODEL": "",
"VAE": "#be616b"
},
"litegraph_base": {
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAACXBIWXMAAAsTAAALEwEAmpwYAAAFu2lUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4gPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iQWRvYmUgWE1QIENvcmUgOS4xLWMwMDEgNzkuMTQ2Mjg5OSwgMjAyMy8wNi8yNS0yMDowMTo1NSAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0RXZ0PSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VFdmVudCMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiB4bXA6Q3JlYXRlRGF0ZT0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgeG1wOk1vZGlmeURhdGU9IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIHhtcDpNZXRhZGF0YURhdGU9IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIGRjOmZvcm1hdD0iaW1hZ2UvcG5nIiBwaG90b3Nob3A6Q29sb3JNb2RlPSIzIiB4bXBNTTpJbnN0YW5jZUlEPSJ4bXAuaWlkOjUwNDFhMmZjLTEzNzQtMTk0ZC1hZWY4LTYxMzM1MTVmNjUwMCIgeG1wTU06RG9jdW1lbnRJRD0ieG1wLmRpZDoyMzFiMTBiMC1iNGZiLTAyNGUtYjEyZS0zMDUzMDNjZDA3YzgiIHhtcE1NOk9yaWdpbmFsRG9jdW1lbnRJRD0ieG1wLmRpZDoyMzFiMTBiMC1iNGZiLTAyNGUtYjEyZS0zMDUzMDNjZDA3YzgiPiA8eG1wTU06SGlzdG9yeT4gPHJkZjpTZXE+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJjcmVhdGVkIiBzdEV2dDppbnN0YW5jZUlEPSJ4bXAuaWlkOjIzMWIxMGIwLWI0ZmItMDI0ZS1iMTJlLTMwNTMwM2NkMDdjOCIgc3RFdnQ6d2hlbj0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgc3RFdnQ6c29mdHdhcmVBZ2VudD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIi8+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJzYXZlZCIgc3RFdnQ6aW5zdGFuY2VJRD0ieG1wLmlpZDo1MDQxYTJmYy0xMzc0LTE5NGQtYWVmOC02MTMzNTE1ZjY1MDAiIHN0RXZ0OndoZW49IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIHN0RXZ0OnNvZnR3YXJlQWdlbnQ9IkFkb2JlIFBob3Rvc2hvcCAyNS4xIChXaW5kb3dzKSIgc3RFdnQ6Y2hhbmdlZD0iLyIvPiA8L3JkZjpTZXE+IDwveG1wTU06SGlzdG9yeT4gPC9yZGY6RGVzY3JpcHRpb24+IDwvcmRmOlJERj4gPC94OnhtcG1ldGE+IDw/eHBhY2tldCBlbmQ9InIiPz73jWg/AAAAyUlEQVR42u3WKwoAIBRFQRdiMb1idv9Lsxn9gEFw4Dbb8JCTojbbXEJwjJVL2HKwYMGCBQuWLbDmjr+9zrBGjHl1WVcvy2DBggULFizTWQpewSt4HzwsgwULFiwFr7MUvMtS8D54WLBgGSxYCl7BK3iXZbBgwYIFC5bpLAWv4BW8Dx6WwYIFC5aC11kK3mUpeB88LFiwDBYsBa/gFbzLMliwYMGCBct0loJX8AreBw/LYMGCBUvB6ywF77IUvA8eFixYBgsWrNfWAZPltufdad+1AAAAAElFTkSuQmCC",
"CLEAR_BACKGROUND_COLOR": "#212732",
"NODE_TITLE_COLOR": "#999",
"NODE_SELECTED_TITLE_COLOR": "#e5eaf0",
"NODE_TEXT_SIZE": 14,
"NODE_TEXT_COLOR": "#bcc2c8",
"NODE_SUBTEXT_SIZE": 12,
"NODE_DEFAULT_COLOR": "#2e3440",
"NODE_DEFAULT_BGCOLOR": "#161b22",
"NODE_DEFAULT_BOXCOLOR": "#545d70",
"NODE_DEFAULT_SHAPE": "box",
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0",
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
"DEFAULT_GROUP_FONT": 24,
"WIDGET_BGCOLOR": "#2e3440",
"WIDGET_OUTLINE_COLOR": "#545d70",
"WIDGET_TEXT_COLOR": "#bcc2c8",
"WIDGET_SECONDARY_TEXT_COLOR": "#999",
"LINK_COLOR": "#9A9",
"EVENT_LINK_COLOR": "#A86",
"CONNECTING_LINK_COLOR": "#AFA"
},
"comfy_base": {
"fg-color": "#e5eaf0",
"bg-color": "#2e3440",
"comfy-menu-bg": "#161b22",
"comfy-input-bg": "#2e3440",
"input-text": "#bcc2c8",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#545d70",
"tr-even-bg-color": "#2e3440",
"tr-odd-bg-color": "#161b22"
}
},
},
"github": {
"id": "github",
"name": "Github",
"colors": {
"node_slot": {
"BOOLEAN": "",
"CLIP": "#eacb8b",
"CLIP_VISION": "#A8DADC",
"CLIP_VISION_OUTPUT": "#ad7452",
"CONDITIONING": "#cf876f",
"CONTROL_NET": "#00d78d",
"CONTROL_NET_WEIGHTS": "",
"FLOAT": "",
"GLIGEN": "",
"IMAGE": "#80a1c0",
"IMAGEUPLOAD": "",
"INT": "",
"LATENT": "#b38ead",
"LATENT_KEYFRAME": "",
"MASK": "#a3bd8d",
"MODEL": "#8978a7",
"SAMPLER": "",
"SIGMAS": "",
"STRING": "",
"STYLE_MODEL": "#C2FFAE",
"T2I_ADAPTER_WEIGHTS": "",
"TAESD": "#DCC274",
"TIMESTEP_KEYFRAME": "",
"UPSCALE_MODEL": "",
"VAE": "#be616b"
},
"litegraph_base": {
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGlmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4gPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iQWRvYmUgWE1QIENvcmUgOS4xLWMwMDEgNzkuMTQ2Mjg5OSwgMjAyMy8wNi8yNS0yMDowMTo1NSAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0RXZ0PSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VFdmVudCMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiB4bXA6Q3JlYXRlRGF0ZT0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgeG1wOk1vZGlmeURhdGU9IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIHhtcDpNZXRhZGF0YURhdGU9IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIGRjOmZvcm1hdD0iaW1hZ2UvcG5nIiBwaG90b3Nob3A6Q29sb3JNb2RlPSIzIiB4bXBNTTpJbnN0YW5jZUlEPSJ4bXAuaWlkOmIyYzRhNjA5LWJmYTctYTg0MC1iOGFlLTk3MzE2ZjM1ZGIyNyIgeG1wTU06RG9jdW1lbnRJRD0iYWRvYmU6ZG9jaWQ6cGhvdG9zaG9wOjk0ZmNlZGU4LTE1MTctZmQ0MC04ZGU3LWYzOTgxM2E3ODk5ZiIgeG1wTU06T3JpZ2luYWxEb2N1bWVudElEPSJ4bXAuZGlkOjIzMWIxMGIwLWI0ZmItMDI0ZS1iMTJlLTMwNTMwM2NkMDdjOCI+IDx4bXBNTTpIaXN0b3J5PiA8cmRmOlNlcT4gPHJkZjpsaSBzdEV2dDphY3Rpb249ImNyZWF0ZWQiIHN0RXZ0Omluc3RhbmNlSUQ9InhtcC5paWQ6MjMxYjEwYjAtYjRmYi0wMjRlLWIxMmUtMzA1MzAzY2QwN2M4IiBzdEV2dDp3aGVuPSIyMDIzLTExLTEzVDAwOjE4OjAyKzAxOjAwIiBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZG9iZSBQaG90b3Nob3AgMjUuMSAoV2luZG93cykiLz4gPHJkZjpsaSBzdEV2dDphY3Rpb249InNhdmVkIiBzdEV2dDppbnN0YW5jZUlEPSJ4bXAuaWlkOjQ4OWY1NzlmLTJkNjUtZWQ0Zi04OTg0LTA4NGE2MGE1ZTMzNSIgc3RFdnQ6d2hlbj0iMjAyMy0xMS0xNVQwMjowNDo1OSswMTowMCIgc3RFdnQ6c29mdHdhcmVBZ2VudD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiBzdEV2dDpjaGFuZ2VkPSIvIi8+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJzYXZlZCIgc3RFdnQ6aW5zdGFuY2VJRD0ieG1wLmlpZDpiMmM0YTYwOS1iZmE3LWE4NDAtYjhhZS05NzMxNmYzNWRiMjciIHN0RXZ0OndoZW49IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIHN0RXZ0OnNvZnR3YXJlQWdlbnQ9IkFkb2JlIFBob3Rvc2hvcCAyNS4xIChXaW5kb3dzKSIgc3RFdnQ6Y2hhbmdlZD0iLyIvPiA8L3JkZjpTZXE+IDwveG1wTU06SGlzdG9yeT4gPC9yZGY6RGVzY3JpcHRpb24+IDwvcmRmOlJERj4gPC94OnhtcG1ldGE+IDw/eHBhY2tldCBlbmQ9InIiPz4OTe6GAAAAx0lEQVR42u3WMQoAIQxFwRzJys77X8vSLiRgITif7bYbgrwYc/mKXyBoY4VVBgsWLFiwYFmOlTv+9jfDOjHmr8u6eVkGCxYsWLBgmc5S8ApewXvgYRksWLBgKXidpeBdloL3wMOCBctgwVLwCl7BuyyDBQsWLFiwTGcpeAWv4D3wsAwWLFiwFLzOUvAuS8F74GHBgmWwYCl4Ba/gXZbBggULFixYprMUvIJX8B54WAYLFixYCl5nKXiXpeA98LBgwTJYsGC9tg1o8f4TTtqzNQAAAABJRU5ErkJggg==",
"CLEAR_BACKGROUND_COLOR": "#040506",
"NODE_TITLE_COLOR": "#999",
"NODE_SELECTED_TITLE_COLOR": "#e5eaf0",
"NODE_TEXT_SIZE": 14,
"NODE_TEXT_COLOR": "#bcc2c8",
"NODE_SUBTEXT_SIZE": 12,
"NODE_DEFAULT_COLOR": "#161b22",
"NODE_DEFAULT_BGCOLOR": "#13171d",
"NODE_DEFAULT_BOXCOLOR": "#30363d",
"NODE_DEFAULT_SHAPE": "box",
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0",
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
"DEFAULT_GROUP_FONT": 24,
"WIDGET_BGCOLOR": "#161b22",
"WIDGET_OUTLINE_COLOR": "#30363d",
"WIDGET_TEXT_COLOR": "#bcc2c8",
"WIDGET_SECONDARY_TEXT_COLOR": "#999",
"LINK_COLOR": "#9A9",
"EVENT_LINK_COLOR": "#A86",
"CONNECTING_LINK_COLOR": "#AFA"
},
"comfy_base": {
"fg-color": "#e5eaf0",
"bg-color": "#161b22",
"comfy-menu-bg": "#13171d",
"comfy-input-bg": "#161b22",
"input-text": "#bcc2c8",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#30363d",
"tr-even-bg-color": "#161b22",
"tr-odd-bg-color": "#13171d"
}
},
}
};

View File

@ -25,7 +25,7 @@ const ext = {
requestAnimationFrame(() => {
const currentNode = LGraphCanvas.active_canvas.current_node;
const clickedComboValue = currentNode.widgets
.filter(w => w.type === "combo" && w.options.values.length === values.length)
?.filter(w => w.type === "combo" && w.options.values.length === values.length)
.find(w => w.options.values.every((v, i) => v === values[i]))
?.value;

View File

@ -14,6 +14,9 @@ import { ComfyDialog, $el } from "../../scripts/ui.js";
// To delete/rename:
// Right click the canvas
// Node templates -> Manage
//
// To rearrange:
// Open the manage dialog and Drag and drop elements using the "Name:" label as handle
const id = "Comfy.NodeTemplates";
@ -22,6 +25,10 @@ class ManageTemplates extends ComfyDialog {
super();
this.element.classList.add("comfy-manage-templates");
this.templates = this.load();
this.draggedEl = null;
this.saveVisualCue = null;
this.emptyImg = new Image();
this.emptyImg.src = 'data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs=';
this.importInput = $el("input", {
type: "file",
@ -35,14 +42,11 @@ class ManageTemplates extends ComfyDialog {
createButtons() {
const btns = super.createButtons();
btns[0].textContent = "Cancel";
btns.unshift(
$el("button", {
type: "button",
textContent: "Save",
onclick: () => this.save(),
})
);
btns[0].textContent = "Close";
btns[0].onclick = (e) => {
clearTimeout(this.saveVisualCue);
this.close();
};
btns.unshift(
$el("button", {
type: "button",
@ -71,25 +75,6 @@ class ManageTemplates extends ComfyDialog {
}
}
save() {
// Find all visible inputs and save them as our new list
const inputs = this.element.querySelectorAll("input");
const updated = [];
for (let i = 0; i < inputs.length; i++) {
const input = inputs[i];
if (input.parentElement.style.display !== "none") {
const t = this.templates[i];
t.name = input.value.trim() || input.getAttribute("data-name");
updated.push(t);
}
}
this.templates = updated;
this.store();
this.close();
}
store() {
localStorage.setItem(id, JSON.stringify(this.templates));
}
@ -145,71 +130,155 @@ class ManageTemplates extends ComfyDialog {
super.show(
$el(
"div",
{
style: {
display: "grid",
gridTemplateColumns: "1fr auto",
gap: "5px",
},
},
this.templates.flatMap((t) => {
{},
this.templates.flatMap((t,i) => {
let nameInput;
return [
$el(
"label",
"div",
{
textContent: "Name: ",
dataset: { id: i },
className: "tempateManagerRow",
style: {
display: "grid",
gridTemplateColumns: "1fr auto",
border: "1px dashed transparent",
gap: "5px",
backgroundColor: "var(--comfy-menu-bg)"
},
ondragstart: (e) => {
this.draggedEl = e.currentTarget;
e.currentTarget.style.opacity = "0.6";
e.currentTarget.style.border = "1px dashed yellow";
e.dataTransfer.effectAllowed = 'move';
e.dataTransfer.setDragImage(this.emptyImg, 0, 0);
},
ondragend: (e) => {
e.target.style.opacity = "1";
e.currentTarget.style.border = "1px dashed transparent";
e.currentTarget.removeAttribute("draggable");
// rearrange the elements in the localStorage
this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => {
var prev_i = el.dataset.id;
if ( el == this.draggedEl && prev_i != i ) {
[this.templates[i], this.templates[prev_i]] = [this.templates[prev_i], this.templates[i]];
}
el.dataset.id = i;
});
this.store();
},
ondragover: (e) => {
e.preventDefault();
if ( e.currentTarget == this.draggedEl )
return;
let rect = e.currentTarget.getBoundingClientRect();
if (e.clientY > rect.top + rect.height / 2) {
e.currentTarget.parentNode.insertBefore(this.draggedEl, e.currentTarget.nextSibling);
} else {
e.currentTarget.parentNode.insertBefore(this.draggedEl, e.currentTarget);
}
}
},
[
$el("input", {
value: t.name,
dataset: { name: t.name },
$: (el) => (nameInput = el),
}),
$el(
"label",
{
textContent: "Name: ",
style: {
cursor: "grab",
},
onmousedown: (e) => {
// enable dragging only from the label
if (e.target.localName == 'label')
e.currentTarget.parentNode.draggable = 'true';
}
},
[
$el("input", {
value: t.name,
dataset: { name: t.name },
style: {
transitionProperty: 'background-color',
transitionDuration: '0s',
},
onchange: (e) => {
clearTimeout(this.saveVisualCue);
var el = e.target;
var row = el.parentNode.parentNode;
this.templates[row.dataset.id].name = el.value.trim() || 'untitled';
this.store();
el.style.backgroundColor = 'rgb(40, 95, 40)';
el.style.transitionDuration = '0s';
this.saveVisualCue = setTimeout(function () {
el.style.transitionDuration = '.7s';
el.style.backgroundColor = 'var(--comfy-input-bg)';
}, 15);
},
onkeypress: (e) => {
var el = e.target;
clearTimeout(this.saveVisualCue);
el.style.transitionDuration = '0s';
el.style.backgroundColor = 'var(--comfy-input-bg)';
},
$: (el) => (nameInput = el),
})
]
),
$el(
"div",
{},
[
$el("button", {
textContent: "Export",
style: {
fontSize: "12px",
fontWeight: "normal",
},
onclick: (e) => {
const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string
const blob = new Blob([json], {type: "application/json"});
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: (nameInput.value || t.name) + ".json",
style: {display: "none"},
parent: document.body,
});
a.click();
setTimeout(function () {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
},
}),
$el("button", {
textContent: "Delete",
style: {
fontSize: "12px",
color: "red",
fontWeight: "normal",
},
onclick: (e) => {
const item = e.target.parentNode.parentNode;
item.parentNode.removeChild(item);
this.templates.splice(item.dataset.id*1, 1);
this.store();
// update the rows index, setTimeout ensures that the list is updated
var that = this;
setTimeout(function (){
that.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => {
el.dataset.id = i;
});
}, 0);
},
}),
]
),
]
),
$el(
"div",
{},
[
$el("button", {
textContent: "Export",
style: {
fontSize: "12px",
fontWeight: "normal",
},
onclick: (e) => {
const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string
const blob = new Blob([json], {type: "application/json"});
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: (nameInput.value || t.name) + ".json",
style: {display: "none"},
parent: document.body,
});
a.click();
setTimeout(function () {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
},
}),
$el("button", {
textContent: "Delete",
style: {
fontSize: "12px",
color: "red",
fontWeight: "normal",
},
onclick: (e) => {
nameInput.value = "";
e.target.parentElement.style.display = "none";
e.target.parentElement.previousElementSibling.style.display = "none";
},
}),
]
),
)
];
})
)

View File

@ -0,0 +1,150 @@
import { app } from "../../scripts/app.js";
const MAX_HISTORY = 50;
let undo = [];
let redo = [];
let activeState = null;
let isOurLoad = false;
function checkState() {
const currentState = app.graph.serialize();
if (!graphEqual(activeState, currentState)) {
undo.push(activeState);
if (undo.length > MAX_HISTORY) {
undo.shift();
}
activeState = clone(currentState);
redo.length = 0;
}
}
const loadGraphData = app.loadGraphData;
app.loadGraphData = async function () {
const v = await loadGraphData.apply(this, arguments);
if (isOurLoad) {
isOurLoad = false;
} else {
checkState();
}
return v;
};
function clone(obj) {
try {
if (typeof structuredClone !== "undefined") {
return structuredClone(obj);
}
} catch (error) {
// structuredClone is stricter than using JSON.parse/stringify so fallback to that
}
return JSON.parse(JSON.stringify(obj));
}
function graphEqual(a, b, root = true) {
if (a === b) return true;
if (typeof a == "object" && a && typeof b == "object" && b) {
const keys = Object.getOwnPropertyNames(a);
if (keys.length != Object.getOwnPropertyNames(b).length) {
return false;
}
for (const key of keys) {
let av = a[key];
let bv = b[key];
if (root && key === "nodes") {
// Nodes need to be sorted as the order changes when selecting nodes
av = [...av].sort((a, b) => a.id - b.id);
bv = [...bv].sort((a, b) => a.id - b.id);
}
if (!graphEqual(av, bv, false)) {
return false;
}
}
return true;
}
return false;
}
const undoRedo = async (e) => {
if (e.ctrlKey || e.metaKey) {
if (e.key === "y") {
const prevState = redo.pop();
if (prevState) {
undo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
return true;
} else if (e.key === "z") {
const prevState = undo.pop();
if (prevState) {
redo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
return true;
}
}
};
const bindInput = (activeEl) => {
if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") {
for (const evt of ["change", "input", "blur"]) {
if (`on${evt}` in activeEl) {
const listener = () => {
checkState();
activeEl.removeEventListener(evt, listener);
};
activeEl.addEventListener(evt, listener);
return true;
}
}
}
};
window.addEventListener(
"keydown",
(e) => {
requestAnimationFrame(async () => {
const activeEl = document.activeElement;
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") {
// Ignore events on inputs, they have their native history
return;
}
// Check if this is a ctrl+z ctrl+y
if (await undoRedo(e)) return;
// If our active element is some type of input then handle changes after they're done
if (bindInput(activeEl)) return;
checkState();
});
},
true
);
// Handle clicking DOM elements (e.g. widgets)
window.addEventListener("mouseup", () => {
checkState();
});
// Handle litegraph clicks
const processMouseUp = LGraphCanvas.prototype.processMouseUp;
LGraphCanvas.prototype.processMouseUp = function (e) {
const v = processMouseUp.apply(this, arguments);
checkState();
return v;
};
const processMouseDown = LGraphCanvas.prototype.processMouseDown;
LGraphCanvas.prototype.processMouseDown = function (e) {
const v = processMouseDown.apply(this, arguments);
checkState();
return v;
};

View File

@ -1,4 +1,4 @@
import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js";
import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
import { app } from "../../scripts/app.js";
const CONVERTED_TYPE = "converted-widget";
@ -467,7 +467,11 @@ app.registerExtension({
if (!control_value) {
control_value = "fixed";
}
addValueControlWidget(this, widget, control_value);
addValueControlWidgets(this, widget, control_value);
let filter = this.widgets_values?.[2];
if(filter && this.widgets.length === 3) {
this.widgets[2].value = filter;
}
}
// When our value changes, update other widgets to reflect our changes

View File

@ -2533,7 +2533,7 @@
var w = this.widgets[i];
if(!w)
continue;
if(w.options && w.options.property && this.properties[ w.options.property ])
if(w.options && w.options.property && (this.properties[ w.options.property ] != undefined))
w.value = JSON.parse( JSON.stringify( this.properties[ w.options.property ] ) );
}
if (info.widgets_values) {
@ -5717,10 +5717,10 @@ LGraphNode.prototype.executeAction = function(action)
* @method enableWebGL
**/
LGraphCanvas.prototype.enableWebGL = function() {
if (typeof GL === undefined) {
if (typeof GL === "undefined") {
throw "litegl.js must be included to use a WebGL canvas";
}
if (typeof enableWebGLCanvas === undefined) {
if (typeof enableWebGLCanvas === "undefined") {
throw "webglCanvas.js must be included to use this feature";
}
@ -7113,15 +7113,16 @@ LGraphNode.prototype.executeAction = function(action)
}
};
LGraphCanvas.prototype.copyToClipboard = function() {
LGraphCanvas.prototype.copyToClipboard = function(nodes) {
var clipboard_info = {
nodes: [],
links: []
};
var index = 0;
var selected_nodes_array = [];
for (var i in this.selected_nodes) {
var node = this.selected_nodes[i];
if (!nodes) nodes = this.selected_nodes;
for (var i in nodes) {
var node = nodes[i];
if (node.clonable === false)
continue;
node._relative_id = index;
@ -11705,7 +11706,7 @@ LGraphNode.prototype.executeAction = function(action)
default:
iS = 0; // try with first if no name set
}
if (typeof options.node_from.outputs[iS] !== undefined){
if (typeof options.node_from.outputs[iS] !== "undefined"){
if (iS!==false && iS>-1){
options.node_from.connectByType( iS, node, options.node_from.outputs[iS].type );
}
@ -11733,7 +11734,7 @@ LGraphNode.prototype.executeAction = function(action)
default:
iS = 0; // try with first if no name set
}
if (typeof options.node_to.inputs[iS] !== undefined){
if (typeof options.node_to.inputs[iS] !== "undefined"){
if (iS!==false && iS>-1){
// try connection
options.node_to.connectByTypeOutput(iS,node,options.node_to.inputs[iS].type);

View File

@ -254,9 +254,9 @@ class ComfyApi extends EventTarget {
* Gets the prompt execution history
* @returns Prompt history including node outputs
*/
async getHistory() {
async getHistory(max_items=200) {
try {
const res = await this.fetchApi("/history");
const res = await this.fetchApi(`/history?max_items=${max_items}`);
return { History: Object.values(await res.json()) };
} catch (error) {
console.error(error);

View File

@ -3,7 +3,26 @@ import { ComfyWidgets } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
import { addDomClippingSetting } from "./domWidget.js";
import { createImageHost, calculateImageGrid } from "./ui/imagePreview.js"
export const ANIM_PREVIEW_WIDGET = "$$comfy_animation_preview"
function sanitizeNodeName(string) {
let entityMap = {
'&': '',
'<': '',
'>': '',
'"': '',
"'": '',
'`': '',
'=': ''
};
return String(string).replace(/[&<>"'`=]/g, function fromEntityMap (s) {
return entityMap[s];
});
}
/**
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
@ -389,7 +408,9 @@ export class ComfyApp {
return shiftY;
}
node.prototype.setSizeForImage = function () {
node.prototype.setSizeForImage = function (force) {
if(!force && this.animatedImages) return;
if (this.inputHeight) {
this.setSize(this.size);
return;
@ -406,13 +427,20 @@ export class ComfyApp {
let imagesChanged = false
const output = app.nodeOutputs[this.id + ""];
if (output && output.images) {
if (output?.images) {
this.animatedImages = output?.animated?.find(Boolean);
if (this.images !== output.images) {
this.images = output.images;
imagesChanged = true;
imgURLs = imgURLs.concat(output.images.map(params => {
return api.apiURL("/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam());
}))
imgURLs = imgURLs.concat(
output.images.map((params) => {
return api.apiURL(
"/view?" +
new URLSearchParams(params).toString() +
(this.animatedImages ? "" : app.getPreviewFormatParam())
);
})
);
}
}
@ -491,7 +519,35 @@ export class ComfyApp {
return true;
}
if (this.imgs && this.imgs.length) {
if (this.imgs?.length) {
const widgetIdx = this.widgets?.findIndex((w) => w.name === ANIM_PREVIEW_WIDGET);
if(this.animatedImages) {
// Instead of using the canvas we'll use a IMG
if(widgetIdx > -1) {
// Replace content
const widget = this.widgets[widgetIdx];
widget.options.host.updateImages(this.imgs);
} else {
const host = createImageHost(this);
this.setSizeForImage(true);
const widget = this.addDOMWidget(ANIM_PREVIEW_WIDGET, "img", host.el, {
host,
getHeight: host.getHeight,
onDraw: host.onDraw,
hideOnZoom: false
});
widget.serializeValue = () => undefined;
widget.options.host.updateImages(this.imgs);
}
return;
}
if (widgetIdx > -1) {
this.widgets[widgetIdx].onRemove?.();
this.widgets.splice(widgetIdx, 1);
}
const canvas = app.graph.list_of_graphcanvas[0];
const mouse = canvas.graph_mouse;
if (!canvas.pointer_is_down && this.pointerDown) {
@ -531,31 +587,7 @@ export class ComfyApp {
}
else {
cell_padding = 0;
let best = 0;
let w = this.imgs[0].naturalWidth;
let h = this.imgs[0].naturalHeight;
// compact style
for (let c = 1; c <= numImages; c++) {
const rows = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / rows;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
shiftX = c * ((cW - imageW) / 2);
}
}
({ cellWidth, cellHeight, cols, shiftX } = calculateImageGrid(this.imgs, dw, dh));
}
let anyHovered = false;
@ -1256,6 +1288,7 @@ export class ComfyApp {
canvasEl.tabIndex = "1";
document.body.prepend(canvasEl);
addDomClippingSetting();
this.#addProcessMouseHandler();
this.#addProcessKeyHandler();
this.#addConfigureHandler();
@ -1453,6 +1486,17 @@ export class ComfyApp {
localStorage.setItem("litegrapheditor_clipboard", old);
}
showMissingNodesError(missingNodeTypes, hasAddedNodes = true) {
this.ui.dialog.show(
`When loading the graph, the following node types were not found: <ul>${Array.from(new Set(missingNodeTypes)).map(
(t) => `<li>${t}</li>`
).join("")}</ul>${hasAddedNodes ? "Nodes that have failed to load will show as red on the graph." : ""}`
);
this.logging.addEntry("Comfy.App", "warn", {
MissingNodes: missingNodeTypes,
});
}
/**
* Populates the graph with the specified workflow data
* @param {*} graphData A serialized graph object
@ -1462,24 +1506,28 @@ export class ComfyApp {
let reset_invalid_values = false;
if (!graphData) {
if (typeof structuredClone === "undefined")
{
graphData = JSON.parse(JSON.stringify(defaultGraph));
}else
{
graphData = structuredClone(defaultGraph);
}
graphData = defaultGraph;
reset_invalid_values = true;
}
if (typeof structuredClone === "undefined")
{
graphData = JSON.parse(JSON.stringify(graphData));
}else
{
graphData = structuredClone(graphData);
}
const missingNodeTypes = [];
for (let n of graphData.nodes) {
// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader";
if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix
if (n.type == "SDV_img2vid_Conditioning") n.type = "SVD_img2vid_Conditioning"; //typo fix
// Find missing node types
if (!(n.type in LiteGraph.registered_node_types)) {
n.type = sanitizeNodeName(n.type);
missingNodeTypes.push(n.type);
}
}
@ -1570,14 +1618,7 @@ export class ComfyApp {
}
if (missingNodeTypes.length) {
this.ui.dialog.show(
`When loading the graph, the following node types were not found: <ul>${Array.from(new Set(missingNodeTypes)).map(
(t) => `<li>${t}</li>`
).join("")}</ul>Nodes that have failed to load will show as red on the graph.`
);
this.logging.addEntry("Comfy.App", "warn", {
MissingNodes: missingNodeTypes,
});
this.showMissingNodesError(missingNodeTypes);
}
}
@ -1589,7 +1630,17 @@ export class ComfyApp {
* @returns The workflow and node links
*/
async graphToPrompt(graph=this.graph, nodeIdOffset=0, path="/", globalMappings={}) {
const workflow = graph.serialize();
for (const node of this.graph.computeExecutionOrder(false)) {
if (node.isVirtualNode) {
// Don't serialize frontend only nodes but let them make changes
if (node.applyToGraph) {
node.applyToGraph();
}
continue;
}
}
const workflow = this.graph.serialize();
const output = {};
let childNodeCount = 0;
@ -1598,10 +1649,6 @@ export class ComfyApp {
const n = workflow.nodes.find((n) => n.id === node.id);
if (node.isVirtualNode) {
// Don't serialize frontend only nodes but let them make changes
if (node.applyToGraph) {
node.applyToGraph(workflow);
}
continue;
}
@ -1862,12 +1909,23 @@ export class ComfyApp {
importA1111(this.graph, pngInfo.parameters);
}
}
} else if (file.type === "image/webp") {
const pngInfo = await getWebpMetadata(file);
if (pngInfo) {
if (pngInfo.workflow) {
this.loadGraphData(JSON.parse(pngInfo.workflow));
} else if (pngInfo.Workflow) {
this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node.
}
}
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
const reader = new FileReader();
reader.onload = () => {
var jsonContent = JSON.parse(reader.result);
const jsonContent = JSON.parse(reader.result);
if (jsonContent?.templates) {
this.loadTemplateData(jsonContent);
} else if(this.isApiJson(jsonContent)) {
this.loadApiJson(jsonContent);
} else {
this.loadGraphData(jsonContent);
}
@ -1881,6 +1939,51 @@ export class ComfyApp {
}
}
isApiJson(data) {
return Object.values(data).every((v) => v.class_type);
}
loadApiJson(apiData) {
const missingNodeTypes = Object.values(apiData).filter((n) => !LiteGraph.registered_node_types[n.class_type]);
if (missingNodeTypes.length) {
this.showMissingNodesError(missingNodeTypes.map(t => t.class_type), false);
return;
}
const ids = Object.keys(apiData);
app.graph.clear();
for (const id of ids) {
const data = apiData[id];
const node = LiteGraph.createNode(data.class_type);
node.id = isNaN(+id) ? id : +id;
graph.add(node);
}
for (const id of ids) {
const data = apiData[id];
const node = app.graph.getNodeById(id);
for (const input in data.inputs ?? {}) {
const value = data.inputs[input];
if (value instanceof Array) {
const [fromId, fromSlot] = value;
const fromNode = app.graph.getNodeById(fromId);
const toSlot = node.inputs?.findIndex((inp) => inp.name === input);
if (toSlot !== -1) {
fromNode.connect(fromSlot, node, toSlot);
}
} else {
const widget = node.widgets?.find((w) => w.name === input);
if (widget) {
widget.value = value;
widget.callback?.(value);
}
}
}
}
app.graph.arrange();
}
/**
* Registers a Comfy web extension with the app
* @param {ComfyExtension} extension

323
web/scripts/domWidget.js Normal file
View File

@ -0,0 +1,323 @@
import { app, ANIM_PREVIEW_WIDGET } from "./app.js";
const SIZE = Symbol();
function intersect(a, b) {
const x = Math.max(a.x, b.x);
const num1 = Math.min(a.x + a.width, b.x + b.width);
const y = Math.max(a.y, b.y);
const num2 = Math.min(a.y + a.height, b.y + b.height);
if (num1 >= x && num2 >= y) return [x, y, num1 - x, num2 - y];
else return null;
}
function getClipPath(node, element, elRect) {
const selectedNode = Object.values(app.canvas.selected_nodes)[0];
if (selectedNode && selectedNode !== node) {
const MARGIN = 7;
const scale = app.canvas.ds.scale;
const bounding = selectedNode.getBounding();
const intersection = intersect(
{ x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale },
{
x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN,
y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN,
width: bounding[2] + MARGIN + MARGIN,
height: bounding[3] + MARGIN + MARGIN,
}
);
if (!intersection) {
return "";
}
const widgetRect = element.getBoundingClientRect();
const clipX = intersection[0] - widgetRect.x / scale + "px";
const clipY = intersection[1] - widgetRect.y / scale + "px";
const clipWidth = intersection[2] + "px";
const clipHeight = intersection[3] + "px";
const path = `polygon(0% 0%, 0% 100%, ${clipX} 100%, ${clipX} ${clipY}, calc(${clipX} + ${clipWidth}) ${clipY}, calc(${clipX} + ${clipWidth}) calc(${clipY} + ${clipHeight}), ${clipX} calc(${clipY} + ${clipHeight}), ${clipX} 100%, 100% 100%, 100% 0%)`;
return path;
}
return "";
}
function computeSize(size) {
if (this.widgets?.[0].last_y == null) return;
let y = this.widgets[0].last_y;
let freeSpace = size[1] - y;
let widgetHeight = 0;
let dom = [];
for (const w of this.widgets) {
if (w.type === "converted-widget") {
// Ignore
delete w.computedHeight;
} else if (w.computeSize) {
widgetHeight += w.computeSize()[1] + 4;
} else if (w.element) {
// Extract DOM widget size info
const styles = getComputedStyle(w.element);
let minHeight = w.options.getMinHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-min-height"));
let maxHeight = w.options.getMaxHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-max-height"));
let prefHeight = w.options.getHeight?.() ?? styles.getPropertyValue("--comfy-widget-height");
if (prefHeight.endsWith?.("%")) {
prefHeight = size[1] * (parseFloat(prefHeight.substring(0, prefHeight.length - 1)) / 100);
} else {
prefHeight = parseInt(prefHeight);
if (isNaN(minHeight)) {
minHeight = prefHeight;
}
}
if (isNaN(minHeight)) {
minHeight = 50;
}
if (!isNaN(maxHeight)) {
if (!isNaN(prefHeight)) {
prefHeight = Math.min(prefHeight, maxHeight);
} else {
prefHeight = maxHeight;
}
}
dom.push({
minHeight,
prefHeight,
w,
});
} else {
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
freeSpace -= widgetHeight;
// Calculate sizes with all widgets at their min height
const prefGrow = []; // Nodes that want to grow to their prefd size
const canGrow = []; // Nodes that can grow to auto size
let growBy = 0;
for (const d of dom) {
freeSpace -= d.minHeight;
if (isNaN(d.prefHeight)) {
canGrow.push(d);
d.w.computedHeight = d.minHeight;
} else {
const diff = d.prefHeight - d.minHeight;
if (diff > 0) {
prefGrow.push(d);
growBy += diff;
d.diff = diff;
} else {
d.w.computedHeight = d.minHeight;
}
}
}
if (this.imgs && !this.widgets.find((w) => w.name === ANIM_PREVIEW_WIDGET)) {
// Allocate space for image
freeSpace -= 220;
}
if (freeSpace < 0) {
// Not enough space for all widgets so we need to grow
size[1] -= freeSpace;
this.graph.setDirtyCanvas(true);
} else {
// Share the space between each
const growDiff = freeSpace - growBy;
if (growDiff > 0) {
// All pref sizes can be fulfilled
freeSpace = growDiff;
for (const d of prefGrow) {
d.w.computedHeight = d.prefHeight;
}
} else {
// We need to grow evenly
const shared = -growDiff / prefGrow.length;
for (const d of prefGrow) {
d.w.computedHeight = d.prefHeight - shared;
}
freeSpace = 0;
}
if (freeSpace > 0 && canGrow.length) {
// Grow any that are auto height
const shared = freeSpace / canGrow.length;
for (const d of canGrow) {
d.w.computedHeight += shared;
}
}
}
// Position each of the widgets
for (const w of this.widgets) {
w.y = y;
if (w.computedHeight) {
y += w.computedHeight;
} else if (w.computeSize) {
y += w.computeSize()[1] + 4;
} else {
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
}
// Override the compute visible nodes function to allow us to hide/show DOM elements when the node goes offscreen
const elementWidgets = new Set();
const computeVisibleNodes = LGraphCanvas.prototype.computeVisibleNodes;
LGraphCanvas.prototype.computeVisibleNodes = function () {
const visibleNodes = computeVisibleNodes.apply(this, arguments);
for (const node of app.graph._nodes) {
if (elementWidgets.has(node)) {
const hidden = visibleNodes.indexOf(node) === -1;
for (const w of node.widgets) {
if (w.element) {
w.element.hidden = hidden;
if (hidden) {
w.options.onHide?.(w);
}
}
}
}
}
return visibleNodes;
};
let enableDomClipping = true;
export function addDomClippingSetting() {
app.ui.settings.addSetting({
id: "Comfy.DOMClippingEnabled",
name: "Enable DOM element clipping (enabling may reduce performance)",
type: "boolean",
defaultValue: enableDomClipping,
onChange(value) {
console.log("enableDomClipping", enableDomClipping);
enableDomClipping = !!value;
},
});
}
LGraphNode.prototype.addDOMWidget = function (name, type, element, options) {
options = { hideOnZoom: true, selectOn: ["focus", "click"], ...options };
if (!element.parentElement) {
document.body.append(element);
}
let mouseDownHandler;
if (element.blur) {
mouseDownHandler = (event) => {
if (!element.contains(event.target)) {
element.blur();
}
};
document.addEventListener("mousedown", mouseDownHandler);
}
const widget = {
type,
name,
get value() {
return options.getValue?.() ?? undefined;
},
set value(v) {
options.setValue?.(v);
widget.callback?.(widget.value);
},
draw: function (ctx, node, widgetWidth, y, widgetHeight) {
if (widget.computedHeight == null) {
computeSize.call(node, node.size);
}
const hidden =
node.flags?.collapsed ||
(!!options.hideOnZoom && app.canvas.ds.scale < 0.5) ||
widget.computedHeight <= 0 ||
widget.type === "converted-widget";
element.hidden = hidden;
element.style.display = hidden ? "none" : null;
if (hidden) {
widget.options.onHide?.(widget);
return;
}
const margin = 10;
const elRect = ctx.canvas.getBoundingClientRect();
const transform = new DOMMatrix()
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
.multiplySelf(ctx.getTransform())
.translateSelf(margin, margin + y);
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d);
Object.assign(element.style, {
transformOrigin: "0 0",
transform: scale,
left: `${transform.a + transform.e}px`,
top: `${transform.d + transform.f}px`,
width: `${widgetWidth - margin * 2}px`,
height: `${(widget.computedHeight ?? 50) - margin * 2}px`,
position: "absolute",
zIndex: app.graph._nodes.indexOf(node),
});
if (enableDomClipping) {
element.style.clipPath = getClipPath(node, element, elRect);
element.style.willChange = "clip-path";
}
this.options.onDraw?.(widget);
},
element,
options,
onRemove() {
if (mouseDownHandler) {
document.removeEventListener("mousedown", mouseDownHandler);
}
element.remove();
},
};
for (const evt of options.selectOn) {
element.addEventListener(evt, () => {
app.canvas.selectNode(this);
app.canvas.bringToFront(this);
});
}
this.addCustomWidget(widget);
elementWidgets.add(this);
const collapse = this.collapse;
this.collapse = function() {
collapse.apply(this, arguments);
if(this.flags?.collapsed) {
element.hidden = true;
element.style.display = "none";
}
}
const onRemoved = this.onRemoved;
this.onRemoved = function () {
element.remove();
elementWidgets.delete(this);
onRemoved?.apply(this, arguments);
};
if (!this[SIZE]) {
this[SIZE] = true;
const onResize = this.onResize;
this.onResize = function (size) {
options.beforeResize?.call(widget, this);
computeSize.call(this, size);
onResize?.apply(this, arguments);
options.afterResize?.call(widget, this);
};
}
return widget;
};

View File

@ -24,7 +24,7 @@ export function getPngMetadata(file) {
const length = dataView.getUint32(offset);
// Get the chunk type
const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8));
if (type === "tEXt") {
if (type === "tEXt" || type == "comf") {
// Get the keyword
let keyword_end = offset + 8;
while (pngData[keyword_end] !== 0) {
@ -47,6 +47,105 @@ export function getPngMetadata(file) {
});
}
function parseExifData(exifData) {
// Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian)
const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949;
// Function to read 16-bit and 32-bit integers from binary data
function readInt(offset, isLittleEndian, length) {
let arr = exifData.slice(offset, offset + length)
if (length === 2) {
return new DataView(arr.buffer, arr.byteOffset, arr.byteLength).getUint16(0, isLittleEndian);
} else if (length === 4) {
return new DataView(arr.buffer, arr.byteOffset, arr.byteLength).getUint32(0, isLittleEndian);
}
}
// Read the offset to the first IFD (Image File Directory)
const ifdOffset = readInt(4, isLittleEndian, 4);
function parseIFD(offset) {
const numEntries = readInt(offset, isLittleEndian, 2);
const result = {};
for (let i = 0; i < numEntries; i++) {
const entryOffset = offset + 2 + i * 12;
const tag = readInt(entryOffset, isLittleEndian, 2);
const type = readInt(entryOffset + 2, isLittleEndian, 2);
const numValues = readInt(entryOffset + 4, isLittleEndian, 4);
const valueOffset = readInt(entryOffset + 8, isLittleEndian, 4);
// Read the value(s) based on the data type
let value;
if (type === 2) {
// ASCII string
value = String.fromCharCode(...exifData.slice(valueOffset, valueOffset + numValues - 1));
}
result[tag] = value;
}
return result;
}
// Parse the first IFD
const ifdData = parseIFD(ifdOffset);
return ifdData;
}
function splitValues(input) {
var output = {};
for (var key in input) {
var value = input[key];
var splitValues = value.split(':', 2);
output[splitValues[0]] = splitValues[1];
}
return output;
}
export function getWebpMetadata(file) {
return new Promise((r) => {
const reader = new FileReader();
reader.onload = (event) => {
const webp = new Uint8Array(event.target.result);
const dataView = new DataView(webp.buffer);
// Check that the WEBP signature is present
if (dataView.getUint32(0) !== 0x52494646 || dataView.getUint32(8) !== 0x57454250) {
console.error("Not a valid WEBP file");
r();
return;
}
// Start searching for chunks after the WEBP signature
let offset = 12;
let txt_chunks = {};
// Loop through the chunks in the WEBP file
while (offset < webp.length) {
const chunk_length = dataView.getUint32(offset + 4, true);
const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4));
if (chunk_type === "EXIF") {
if (String.fromCharCode(...webp.slice(offset + 8, offset + 8 + 6)) == "Exif\0\0") {
offset += 6;
}
let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length));
for (var key in data) {
var value = data[key];
let index = value.indexOf(':');
txt_chunks[value.slice(0, index)] = value.slice(index + 1);
}
}
offset += 8 + chunk_length;
}
r(txt_chunks);
};
reader.readAsArrayBuffer(file);
});
}
export function getLatentMetadata(file) {
return new Promise((r) => {
const reader = new FileReader();

View File

@ -599,7 +599,7 @@ export class ComfyUI {
const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",
accept: ".json,image/png,.latent,.safetensors",
accept: ".json,image/png,.latent,.safetensors,image/webp",
style: {display: "none"},
parent: document.body,
onchange: () => {
@ -719,20 +719,22 @@ export class ComfyUI {
filename += ".json";
}
}
const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string
const blob = new Blob([json], {type: "application/json"});
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: filename,
style: {display: "none"},
parent: document.body,
app.graphToPrompt().then(p=>{
const json = JSON.stringify(p.workflow, null, 2); // convert the data to a JSON string
const blob = new Blob([json], {type: "application/json"});
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: filename,
style: {display: "none"},
parent: document.body,
});
a.click();
setTimeout(function () {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
});
a.click();
setTimeout(function () {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
},
}),
$el("button", {

View File

@ -0,0 +1,97 @@
import { $el } from "../ui.js";
export function calculateImageGrid(imgs, dw, dh) {
let best = 0;
let w = imgs[0].naturalWidth;
let h = imgs[0].naturalHeight;
const numImages = imgs.length;
let cellWidth, cellHeight, cols, rows, shiftX;
// compact style
for (let c = 1; c <= numImages; c++) {
const r = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / r;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
rows = r;
shiftX = c * ((cW - imageW) / 2);
}
}
return { cellWidth, cellHeight, cols, rows, shiftX };
}
export function createImageHost(node) {
const el = $el("div.comfy-img-preview");
let currentImgs;
let first = true;
function updateSize() {
let w = null;
let h = null;
if (currentImgs) {
let elH = el.clientHeight;
if (first) {
first = false;
// On first run, if we are small then grow a bit
if (elH < 190) {
elH = 190;
}
el.style.setProperty("--comfy-widget-min-height", elH);
} else {
el.style.setProperty("--comfy-widget-min-height", null);
}
const nw = node.size[0];
({ cellWidth: w, cellHeight: h } = calculateImageGrid(currentImgs, nw - 20, elH));
w += "px";
h += "px";
el.style.setProperty("--comfy-img-preview-width", w);
el.style.setProperty("--comfy-img-preview-height", h);
}
}
return {
el,
updateImages(imgs) {
if (imgs !== currentImgs) {
if (currentImgs == null) {
requestAnimationFrame(() => {
updateSize();
});
}
el.replaceChildren(...imgs);
currentImgs = imgs;
node.onResize(node.size);
node.graph.setDirtyCanvas(true, true);
}
},
getHeight() {
updateSize();
},
onDraw() {
// Element from point uses a hittest find elements so we need to toggle pointer events
el.style.pointerEvents = "all";
const over = document.elementFromPoint(app.canvas.mouse[0], app.canvas.mouse[1]);
el.style.pointerEvents = "none";
if(!over) return;
// Set the overIndex so Open Image etc work
const idx = currentImgs.indexOf(over);
node.overIndex = idx;
},
};
}

View File

@ -1,5 +1,6 @@
import { api } from "./api.js"
import { getPngMetadata } from "./pnginfo.js";
import "./domWidget.js";
function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
let defaultVal = inputData[1]["default"];
@ -24,17 +25,58 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
}
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) {
const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, {
const widgets = addValueControlWidgets(node, targetWidget, defaultValue, values, {
addFilterList: false,
});
return widgets[0];
}
export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", values, options) {
if (!options) options = {};
const widgets = [];
const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, {
values: ["fixed", "increment", "decrement", "randomize"],
serialize: false, // Don't include this in prompt.
});
valueControl.afterQueued = () => {
widgets.push(valueControl);
const isCombo = targetWidget.type === "combo";
let comboFilter;
if (isCombo && options.addFilterList !== false) {
comboFilter = node.addWidget("string", "control_filter_list", "", function (v) {}, {
serialize: false, // Don't include this in prompt.
});
widgets.push(comboFilter);
}
valueControl.afterQueued = () => {
var v = valueControl.value;
if (targetWidget.type == "combo" && v !== "fixed") {
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
let current_length = targetWidget.options.values.length;
if (isCombo && v !== "fixed") {
let values = targetWidget.options.values;
const filter = comboFilter?.value;
if (filter) {
let check;
if (filter.startsWith("/") && filter.endsWith("/")) {
try {
const regex = new RegExp(filter.substring(1, filter.length - 1));
check = (item) => regex.test(item);
} catch (error) {
console.error("Error constructing RegExp filter for node " + node.id, filter, error);
}
}
if (!check) {
const lower = filter.toLocaleLowerCase();
check = (item) => item.toLocaleLowerCase().includes(lower);
}
values = values.filter(item => check(item));
if (!values.length && targetWidget.options.values.length) {
console.warn("Filter for node " + node.id + " has filtered out all items", filter);
}
}
let current_index = values.indexOf(targetWidget.value);
let current_length = values.length;
switch (v) {
case "increment":
@ -51,7 +93,7 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
current_index = Math.max(0, current_index);
current_index = Math.min(current_length - 1, current_index);
if (current_index >= 0) {
let value = targetWidget.options.values[current_index];
let value = values[current_index];
targetWidget.value = value;
targetWidget.callback(value);
}
@ -88,7 +130,8 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
targetWidget.callback(targetWidget.value);
}
}
return valueControl;
return widgets;
};
function seedWidget(node, inputName, inputData, app) {
@ -98,166 +141,21 @@ function seedWidget(node, inputName, inputData, app) {
seed.widget.linkedWidgets = [seedControl];
return seed;
}
const MultilineSymbol = Symbol();
const MultilineResizeSymbol = Symbol();
function addMultilineWidget(node, name, opts, app) {
const MIN_SIZE = 50;
const inputEl = document.createElement("textarea");
inputEl.className = "comfy-multiline-input";
inputEl.value = opts.defaultVal;
inputEl.placeholder = opts.placeholder || "";
function computeSize(size) {
if (node.widgets[0].last_y == null) return;
let y = node.widgets[0].last_y;
let freeSpace = size[1] - y;
// Compute the height of all non customtext widgets
let widgetHeight = 0;
const multi = [];
for (let i = 0; i < node.widgets.length; i++) {
const w = node.widgets[i];
if (w.type === "customtext") {
multi.push(w);
} else {
if (w.computeSize) {
widgetHeight += w.computeSize()[1] + 4;
} else {
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
}
// See how large each text input can be
freeSpace -= widgetHeight;
freeSpace /= multi.length + (!!node.imgs?.length);
if (freeSpace < MIN_SIZE) {
// There isnt enough space for all the widgets, increase the size of the node
freeSpace = MIN_SIZE;
node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length));
node.graph.setDirtyCanvas(true);
}
// Position each of the widgets
for (const w of node.widgets) {
w.y = y;
if (w.type === "customtext") {
y += freeSpace;
w.computedHeight = freeSpace - multi.length*4;
} else if (w.computeSize) {
y += w.computeSize()[1] + 4;
} else {
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
node.inputHeight = freeSpace;
}
const widget = {
type: "customtext",
name,
get value() {
return this.inputEl.value;
const widget = node.addDOMWidget(name, "customtext", inputEl, {
getValue() {
return inputEl.value;
},
set value(x) {
this.inputEl.value = x;
setValue(v) {
inputEl.value = v;
},
draw: function (ctx, _, widgetWidth, y, widgetHeight) {
if (!this.parent.inputHeight) {
// If we are initially offscreen when created we wont have received a resize event
// Calculate it here instead
computeSize(node.size);
}
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
const margin = 10;
const elRect = ctx.canvas.getBoundingClientRect();
const transform = new DOMMatrix()
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
.multiplySelf(ctx.getTransform())
.translateSelf(margin, margin + y);
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d)
Object.assign(this.inputEl.style, {
transformOrigin: "0 0",
transform: scale,
left: `${transform.a + transform.e}px`,
top: `${transform.d + transform.f}px`,
width: `${widgetWidth - (margin * 2)}px`,
height: `${this.parent.inputHeight - (margin * 2)}px`,
position: "absolute",
background: (!node.color)?'':node.color,
color: (!node.color)?'':'white',
zIndex: app.graph._nodes.indexOf(node),
});
this.inputEl.hidden = !visible;
},
};
widget.inputEl = document.createElement("textarea");
widget.inputEl.className = "comfy-multiline-input";
widget.inputEl.value = opts.defaultVal;
widget.inputEl.placeholder = opts.placeholder || "";
document.addEventListener("mousedown", function (event) {
if (!widget.inputEl.contains(event.target)) {
widget.inputEl.blur();
}
});
widget.parent = node;
document.body.appendChild(widget.inputEl);
node.addCustomWidget(widget);
app.canvas.onDrawBackground = function () {
// Draw node isnt fired once the node is off the screen
// if it goes off screen quickly, the input may not be removed
// this shifts it off screen so it can be moved back if the node is visible.
for (let n in app.graph._nodes) {
n = graph._nodes[n];
for (let w in n.widgets) {
let wid = n.widgets[w];
if (Object.hasOwn(wid, "inputEl")) {
wid.inputEl.style.left = -8000 + "px";
wid.inputEl.style.position = "absolute";
}
}
}
};
node.onRemoved = function () {
// When removing this node we need to remove the input from the DOM
for (let y in this.widgets) {
if (this.widgets[y].inputEl) {
this.widgets[y].inputEl.remove();
}
}
};
widget.onRemove = () => {
widget.inputEl?.remove();
// Restore original size handler if we are the last
if (!--node[MultilineSymbol]) {
node.onResize = node[MultilineResizeSymbol];
delete node[MultilineSymbol];
delete node[MultilineResizeSymbol];
}
};
if (node[MultilineSymbol]) {
node[MultilineSymbol]++;
} else {
node[MultilineSymbol] = 1;
const onResize = (node[MultilineResizeSymbol] = node.onResize);
node.onResize = function (size) {
computeSize(size);
// Call original resizer handler
if (onResize) {
onResize.apply(this, arguments);
}
};
}
widget.inputEl = inputEl;
return { minWidth: 400, minHeight: 200, widget };
}
@ -306,14 +204,23 @@ export const ComfyWidgets = {
};
},
BOOLEAN(node, inputName, inputData) {
let defaultVal = inputData[1]["default"];
let defaultVal = false;
let options = {};
if (inputData[1]) {
if (inputData[1].default)
defaultVal = inputData[1].default;
if (inputData[1].label_on)
options["on"] = inputData[1].label_on;
if (inputData[1].label_off)
options["off"] = inputData[1].label_off;
}
return {
widget: node.addWidget(
"toggle",
inputName,
defaultVal,
() => {},
{"on": inputData[1].label_on, "off": inputData[1].label_off}
options,
)
};
},

View File

@ -409,6 +409,21 @@ dialog::backdrop {
width: calc(100% - 10px);
}
.comfy-img-preview {
pointer-events: none;
overflow: hidden;
display: flex;
flex-wrap: wrap;
align-content: flex-start;
justify-content: center;
}
.comfy-img-preview img {
object-fit: contain;
width: var(--comfy-img-preview-width);
height: var(--comfy-img-preview-height);
}
/* Search box */
.litegraph.litesearchbox {