diff --git a/README.md b/README.md index d622c9072..af1f22811 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. -- Fully supports SD1.x, SD2.x and SDXL +- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram) @@ -30,6 +30,8 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/) - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/) - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/) +- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) +- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Starts up very fast. - Works fully offline: will never download anything. diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index f982d648c..76a525b37 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -27,7 +27,6 @@ class ControlNet(nn.Module): model_channels, hint_channels, num_res_blocks, - attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, @@ -52,8 +51,10 @@ class ControlNet(nn.Module): use_linear_in_transformer=False, adm_in_channels=None, transformer_depth_middle=None, + transformer_depth_output=None, device=None, operations=comfy.ops, + **kwargs, ): super().__init__() assert use_spatial_transformer == True, "use_spatial_transformer has to be true" @@ -79,10 +80,7 @@ class ControlNet(nn.Module): self.image_size = image_size self.in_channels = in_channels self.model_channels = model_channels - if isinstance(transformer_depth, int): - transformer_depth = len(channel_mult) * [transformer_depth] - if transformer_depth_middle is None: - transformer_depth_middle = transformer_depth[-1] + if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: @@ -90,18 +88,16 @@ class ControlNet(nn.Module): raise ValueError("provide num_res_blocks either as an int (globally constant) or " "as a list/tuple (per-level) with the same length as channel_mult") self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") - self.attention_resolutions = attention_resolutions + transformer_depth = transformer_depth[:] + self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample @@ -180,11 +176,14 @@ class ControlNet(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, - operations=operations + dtype=self.dtype, + device=device, + operations=operations, ) ] ch = mult * model_channels - if ds in attention_resolutions: + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: if num_head_channels == -1: dim_head = ch // num_heads else: @@ -201,9 +200,9 @@ class ControlNet(nn.Module): if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: layers.append( SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, operations=operations + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -223,11 +222,13 @@ class ControlNet(nn.Module): use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, + dtype=self.dtype, + device=device, operations=operations ) if resblock_updown else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations + ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations ) ) ) @@ -245,7 +246,7 @@ class ControlNet(nn.Module): if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( + mid_block = [ ResBlock( ch, time_embed_dim, @@ -253,12 +254,15 @@ class ControlNet(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + device=device, operations=operations - ), - SpatialTransformer( # always uses a self-attn + )] + if transformer_depth_middle >= 0: + mid_block += [SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, operations=operations + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ), ResBlock( ch, @@ -267,9 +271,11 @@ class ControlNet(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + device=device, operations=operations - ), - ) + )] + self.middle_block = TimestepEmbedSequential(*mid_block) self.middle_block_out = self.make_zero_conv(ch, operations=operations) self._feature_size += ch diff --git a/comfy/cli_args.py b/comfy/cli_args.py index d86557646..72fce1087 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -36,6 +36,8 @@ parser = argparse.ArgumentParser() parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. 
(listens on all)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") +parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.") + parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).") @@ -60,6 +62,13 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") +fpte_group = parser.add_mutually_exclusive_group() +fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).") +fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") +fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.") +fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") + + parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.") diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index e085186ef..9e2e03d72 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,5 +1,5 @@ -from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils -from .utils import load_torch_file, transformers_convert +from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils +from .utils import load_torch_file, transformers_convert, common_upscale import os import torch import contextlib @@ -7,6 +7,18 @@ import contextlib import comfy.ops import comfy.model_patcher import comfy.model_management +import comfy.utils + +def clip_preprocess(image, size=224): + mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) + std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) + scale = (size / min(image.shape[1], image.shape[2])) + image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] + image = torch.clip((255. 
* image), 0, 255).round() / 255.0 + return (image - mean.view([3,1,1])) / std.view([3,1,1]) class ClipVisionModel(): def __init__(self, json_config): @@ -23,25 +35,12 @@ class ClipVisionModel(): self.model.to(self.dtype) self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) - self.processor = CLIPImageProcessor(crop_size=224, - do_center_crop=True, - do_convert_rgb=True, - do_normalize=True, - do_resize=True, - image_mean=[ 0.48145466,0.4578275,0.40821073], - image_std=[0.26862954,0.26130258,0.27577711], - resample=3, #bicubic - size=224) - def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False) def encode_image(self, image): - img = torch.clip((255. * image), 0, 255).round().int() - img = list(map(lambda a: a, img)) - inputs = self.processor(images=img, return_tensors="pt") comfy.model_management.load_model_gpu(self.patcher) - pixel_values = inputs['pixel_values'].to(self.load_device) + pixel_values = clip_preprocess(image.to(self.load_device)) if self.dtype != torch.float32: precision_scope = torch.autocast diff --git a/comfy/conds.py b/comfy/conds.py new file mode 100644 index 000000000..6cff25184 --- /dev/null +++ b/comfy/conds.py @@ -0,0 +1,79 @@ +import enum +import torch +import math +import comfy.utils + + +def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) + return abs(a*b) // math.gcd(a, b) + +class CONDRegular: + def __init__(self, cond): + self.cond = cond + + def _copy_with(self, cond): + return self.__class__(cond) + + def process_cond(self, batch_size, device, **kwargs): + return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device)) + + def can_concat(self, other): + if self.cond.shape != other.cond.shape: + return False + return True + + def concat(self, others): + conds = [self.cond] + for x in others: + conds.append(x.cond) + return torch.cat(conds) + +class CONDNoiseShape(CONDRegular): + def process_cond(self, batch_size, device, area, **kwargs): + data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device)) + + +class CONDCrossAttn(CONDRegular): + def can_concat(self, other): + s1 = self.cond.shape + s2 = other.cond.shape + if s1 != s2: + if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen + return False + + mult_min = lcm(s1[1], s2[1]) + diff = mult_min // min(s1[1], s2[1]) + if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much + return False + return True + + def concat(self, others): + conds = [self.cond] + crossattn_max_len = self.cond.shape[1] + for x in others: + c = x.cond + crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) + conds.append(c) + + out = [] + for c in conds: + if c.shape[1] < crossattn_max_len: + c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result + out.append(c) + return torch.cat(out) + +class CONDConstant(CONDRegular): + def __init__(self, cond): + self.cond = cond + + def process_cond(self, batch_size, device, **kwargs): + return self._copy_with(self.cond) + + def can_concat(self, other): + if self.cond != other.cond: + return False + return True + + def concat(self, others): + return self.cond diff --git a/comfy/controlnet.py b/comfy/controlnet.py index f1355e64e..433381df6 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -33,7 +33,7 @@ class ControlBase: 
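A rough standalone illustration of the CONDCrossAttn.concat logic from the new comfy/conds.py above (the helper name pad_and_concat_crossattn and the toy tensors are illustrative, not part of the patch): cross-attention conds with different token counts are repeat-padded to their least common multiple so they can share one batched forward pass, and can_concat refuses the merge when the required padding factor gets too large.

import math
import torch

def pad_and_concat_crossattn(cond_a, cond_b):
    # mirrors CONDCrossAttn.concat: repeat each cond along the token axis up to
    # the lcm of the two lengths, then concatenate along the batch dimension
    max_len = math.lcm(cond_a.shape[1], cond_b.shape[1])  # conds.py ships its own lcm() for Python < 3.9
    out = []
    for c in (cond_a, cond_b):
        if c.shape[1] < max_len:
            c = c.repeat(1, max_len // c.shape[1], 1)  # repeating tokens doesn't change the attention result
        out.append(c)
    return torch.cat(out)

# a 77-token and a 154-token CLIP conditioning become one (2, 154, 768) batch
print(pad_and_concat_crossattn(torch.randn(1, 77, 768), torch.randn(1, 154, 768)).shape)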
self.cond_hint_original = None self.cond_hint = None self.strength = 1.0 - self.timestep_percent_range = (1.0, 0.0) + self.timestep_percent_range = (0.0, 1.0) self.timestep_range = None if device is None: @@ -42,7 +42,7 @@ class ControlBase: self.previous_controlnet = None self.global_average_pooling = False - def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)): + def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)): self.cond_hint_original = cond_hint self.strength = strength self.timestep_percent_range = timestep_percent_range @@ -132,6 +132,7 @@ class ControlNet(ControlBase): self.control_model = control_model self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) self.global_average_pooling = global_average_pooling + self.model_sampling_current = None def get_control(self, x_noisy, t, cond, batched_number): control_prev = None @@ -156,10 +157,13 @@ class ControlNet(ControlBase): context = cond['c_crossattn'] - y = cond.get('c_adm', None) + y = cond.get('y', None) if y is not None: y = y.to(self.control_model.dtype) - control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) + timestep = self.model_sampling_current.timestep(t) + x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) + + control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y) return self.control_merge(None, control, control_prev, output_dtype) def copy(self): @@ -172,6 +176,14 @@ class ControlNet(ControlBase): out.append(self.control_model_wrapped) return out + def pre_run(self, model, percent_to_timestep_function): + super().pre_run(model, percent_to_timestep_function) + self.model_sampling_current = model.model_sampling + + def cleanup(self): + self.model_sampling_current = None + super().cleanup() + class ControlLoraOps: class Linear(torch.nn.Module): def __init__(self, in_features: int, out_features: int, bias: bool = True, diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 58e030d04..08bf0fc9e 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -852,7 +852,13 @@ class SigmaConvert: log_std = 0.5 * torch.log(1. - torch.exp(2. 
* log_mean_coeff)) return log_mean_coeff - log_std -def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): +def predict_eps_sigma(model, input, sigma_in, **kwargs): + sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1)) + input = input * ((sigma ** 2 + 1.0) ** 0.5) + return (input - model(input, sigma_in, **kwargs)) / sigma + + +def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): timesteps = sigmas.clone() if sigmas[-1] == 0: timesteps = sigmas[:] @@ -874,14 +880,14 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex model_type = "noise" model_fn = model_wrapper( - model.predict_eps_sigma, + lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs), ns, model_type=model_type, guidance_type="uncond", model_kwargs=extra_args, ) - order = min(3, len(timesteps) - 1) + order = min(3, len(timesteps) - 2) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) x /= ns.marginal_alpha(timesteps[-1]) diff --git a/comfy/k_diffusion/external.py b/comfy/k_diffusion/external.py deleted file mode 100644 index 953d3db2c..000000000 --- a/comfy/k_diffusion/external.py +++ /dev/null @@ -1,194 +0,0 @@ -import math - -import torch -from torch import nn - -from . import sampling, utils - - -class VDenoiser(nn.Module): - """A v-diffusion-pytorch model wrapper for k-diffusion.""" - - def __init__(self, inner_model): - super().__init__() - self.inner_model = inner_model - self.sigma_data = 1. 
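The predict_eps_sigma helper added to uni_pc.py above stands in for the method that previously lived in the now-deleted k_diffusion/external.py wrappers. A hedged sketch of the conversion it performs (predict_eps_from_denoiser and the toy lambda are illustrative names): the UniPC solver carries a variance-preserving latent roughly equal to (x0 + sigma * eps) / sqrt(sigma^2 + 1), so the helper rescales it, asks the denoiser for its x0 prediction, and solves for eps.

import torch

def predict_eps_from_denoiser(denoiser, x_vp, sigma):
    # rescale the VP-parameterized latent back to x0 + sigma * eps,
    # get the x0 prediction, then recover the noise estimate
    sigma_b = sigma.view(sigma.shape[:1] + (1,) * (x_vp.ndim - 1))
    x = x_vp * (sigma_b ** 2 + 1.0) ** 0.5
    return (x - denoiser(x, sigma)) / sigma_b

# toy check: a "denoiser" that already knows the clean latent recovers eps exactly
x0 = torch.randn(2, 4, 8, 8)
eps = torch.randn_like(x0)
sigma = torch.full((2,), 3.0)
x_vp = (x0 + 3.0 * eps) / (3.0 ** 2 + 1.0) ** 0.5
print(torch.allclose(predict_eps_from_denoiser(lambda x, s: x0, x_vp, sigma), eps, atol=1e-5))  # True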
- - def get_scalings(self, sigma): - c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - return c_skip, c_out, c_in - - def sigma_to_t(self, sigma): - return sigma.atan() / math.pi * 2 - - def t_to_sigma(self, t): - return (t * math.pi / 2).tan() - - def loss(self, input, noise, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - noised_input = input + noise * utils.append_dims(sigma, input.ndim) - model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) - target = (input - c_skip * noised_input) / c_out - return (model_output - target).pow(2).flatten(1).mean(1) - - def forward(self, input, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip - - -class DiscreteSchedule(nn.Module): - """A mapping between continuous noise levels (sigmas) and a list of discrete noise - levels.""" - - def __init__(self, sigmas, quantize): - super().__init__() - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) - self.quantize = quantize - - @property - def sigma_min(self): - return self.sigmas[0] - - @property - def sigma_max(self): - return self.sigmas[-1] - - def get_sigmas(self, n=None): - if n is None: - return sampling.append_zero(self.sigmas.flip(0)) - t_max = len(self.sigmas) - 1 - t = torch.linspace(t_max, 0, n, device=self.sigmas.device) - return sampling.append_zero(self.t_to_sigma(t)) - - def sigma_to_discrete_timestep(self, sigma): - log_sigma = sigma.log() - dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] - return dists.abs().argmin(dim=0).view(sigma.shape) - - def sigma_to_t(self, sigma, quantize=None): - quantize = self.quantize if quantize is None else quantize - if quantize: - return self.sigma_to_discrete_timestep(sigma) - log_sigma = sigma.log() - dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] - low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx] - w = (low - log_sigma) / (low - high) - w = w.clamp(0, 1) - t = (1 - w) * low_idx + w * high_idx - return t.view(sigma.shape) - - def t_to_sigma(self, t): - t = t.float() - low_idx = t.floor().long() - high_idx = t.ceil().long() - w = t-low_idx if t.device.type == 'mps' else t.frac() - log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] - return log_sigma.exp() - - def predict_eps_discrete_timestep(self, input, t, **kwargs): - if t.dtype != torch.int64 and t.dtype != torch.int32: - t = t.round() - sigma = self.t_to_sigma(t) - input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5) - return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim) - - def predict_eps_sigma(self, input, sigma, **kwargs): - input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5) - return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim) - -class DiscreteEpsDDPMDenoiser(DiscreteSchedule): - """A wrapper for discrete schedule DDPM models that output eps (the predicted - noise).""" - - def __init__(self, model, alphas_cumprod, quantize): - 
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) - self.inner_model = model - self.sigma_data = 1. - - def get_scalings(self, sigma): - c_out = -sigma - c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - return c_out, c_in - - def get_eps(self, *args, **kwargs): - return self.inner_model(*args, **kwargs) - - def loss(self, input, noise, sigma, **kwargs): - c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - noised_input = input + noise * utils.append_dims(sigma, input.ndim) - eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) - return (eps - noise).pow(2).flatten(1).mean(1) - - def forward(self, input, sigma, **kwargs): - c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) - return input + eps * c_out - - -class OpenAIDenoiser(DiscreteEpsDDPMDenoiser): - """A wrapper for OpenAI diffusion models.""" - - def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'): - alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32) - super().__init__(model, alphas_cumprod, quantize=quantize) - self.has_learned_sigmas = has_learned_sigmas - - def get_eps(self, *args, **kwargs): - model_output = self.inner_model(*args, **kwargs) - if self.has_learned_sigmas: - return model_output.chunk(2, dim=1)[0] - return model_output - - -class CompVisDenoiser(DiscreteEpsDDPMDenoiser): - """A wrapper for CompVis diffusion models.""" - - def __init__(self, model, quantize=False, device='cpu'): - super().__init__(model, model.alphas_cumprod, quantize=quantize) - - def get_eps(self, *args, **kwargs): - return self.inner_model.apply_model(*args, **kwargs) - - -class DiscreteVDDPMDenoiser(DiscreteSchedule): - """A wrapper for discrete schedule DDPM models that output v.""" - - def __init__(self, model, alphas_cumprod, quantize): - super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) - self.inner_model = model - self.sigma_data = 1. 
- - def get_scalings(self, sigma): - c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - return c_skip, c_out, c_in - - def get_v(self, *args, **kwargs): - return self.inner_model(*args, **kwargs) - - def loss(self, input, noise, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - noised_input = input + noise * utils.append_dims(sigma, input.ndim) - model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) - target = (input - c_skip * noised_input) / c_out - return (model_output - target).pow(2).flatten(1).mean(1) - - def forward(self, input, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip - - -class CompVisVDenoiser(DiscreteVDDPMDenoiser): - """A wrapper for CompVis diffusion models that output v.""" - - def __init__(self, model, quantize=False, device='cpu'): - super().__init__(model, model.alphas_cumprod, quantize=quantize) - - def get_v(self, x, t, cond, **kwargs): - return self.inner_model.apply_model(x, t, cond) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 937c5a388..761c2e0ef 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler): mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) return mu - def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): extra_args = {} if extra_args is None else extra_args noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler @@ -737,3 +736,75 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) +@torch.no_grad() +def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + x = denoised + if sigmas[i + 1] > 0: + x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) + return x + + + +@torch.no_grad() +def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/ + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + s_end = sigmas[-1] + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
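The sample_lcm sampler added above performs one denoiser call per step and then simply re-noises the x0 prediction to the next sigma instead of integrating an ODE. A minimal standalone sketch of that loop, assuming a toy zero-predicting model in place of a real LCM-distilled denoiser:

import torch

@torch.no_grad()
def lcm_loop(model, x, sigmas):
    # same stepping rule as sample_lcm: take the model's denoised prediction,
    # then add fresh noise scaled by the next sigma (skipped once sigma reaches 0)
    s_in = x.new_ones([x.shape[0]])
    for i in range(len(sigmas) - 1):
        x = model(x, sigmas[i] * s_in)
        if sigmas[i + 1] > 0:
            x = x + sigmas[i + 1] * torch.randn_like(x)
    return x

toy_model = lambda x, sigma: torch.zeros_like(x)  # illustrative stand-in only
sigmas = torch.tensor([14.6, 3.0, 0.8, 0.0])
sample = lcm_loop(toy_model, torch.randn(1, 4, 64, 64) * sigmas[0], sigmas)
print(sample.abs().max())  # tensor(0.) for the zero-predicting toy model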
+ eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == s_end: + # Euler method + x = x + d * dt + elif sigmas[i + 2] == s_end: + + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + + w = 2 * sigmas[0] + w2 = sigmas[i+1]/w + w1 = 1 - w2 + + d_prime = d * w1 + d_2 * w2 + + + x = x + d_prime * dt + + else: + # Heun++ + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + dt_2 = sigmas[i + 2] - sigmas[i + 1] + + x_3 = x_2 + d_2 * dt_2 + denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args) + d_3 = to_d(x_3, sigmas[i + 2], denoised_3) + + w = 3 * sigmas[0] + w2 = sigmas[i + 1] / w + w3 = sigmas[i + 2] / w + w1 = 1 - w2 - w3 + + d_prime = w1 * d + w2 * d_2 + w3 * d_3 + x = x + d_prime * dt + return x diff --git a/comfy/ldm/models/diffusion/__init__.py b/comfy/ldm/models/diffusion/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py deleted file mode 100644 index 433d48e30..000000000 --- a/comfy/ldm/models/diffusion/ddim.py +++ /dev/null @@ -1,418 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm - -from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor - - -class DDIMSampler(object): - def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - self.device = device - self.parameterization = kwargs.get("parameterization", "eps") - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != self.device: - attr = attr.float().to(self.device) - setattr(self, name, attr) - - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) - self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose) - - def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True): - self.ddim_timesteps = torch.tensor(ddim_timesteps) - alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) - - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. 
- alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) - - @torch.no_grad() - def sample_custom(self, - ddim_timesteps, - conditioning=None, - callback=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - ucg_schedule=None, - denoise_function=None, - extra_args=None, - to_zero=True, - end_step=None, - disable_pbar=False, - **kwargs - ): - self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose) - samples, intermediates = self.ddim_sampling(conditioning, x_T.shape, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule, - denoise_function=denoise_function, - extra_args=extra_args, - to_zero=to_zero, - end_step=end_step, - disable_pbar=disable_pbar - ) - return samples, intermediates - - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
- dynamic_threshold=None, - ucg_schedule=None, - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] - cbs = ctmp.shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - - elif isinstance(conditioning, list): - for ctmp in conditioning: - if ctmp.shape[0] != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule, - denoise_function=None, - extra_args=None - ) - return samples, intermediates - - def q_sample(self, x_start, t, noise=None): - if noise is None: - noise = torch.randn_like(x_start) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) - - @torch.no_grad() - def ddim_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False): - device = self.model.alphas_cumprod.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else timesteps.flip(0) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - # print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar) - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - - if mask is not None: - assert x0 is not None - img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. 
- mask) * img - - if ucg_schedule is not None: - assert len(ucg_schedule) == len(time_range) - unconditional_guidance_scale = ucg_schedule[i] - - outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args) - img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - if to_zero: - img = pred_x0 - else: - if ddim_use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - img /= sqrt_alphas_cumprod[index - 1] - - return img, intermediates - - @torch.no_grad() - def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None, denoise_function=None, extra_args=None): - b, *_, device = *x.shape, x.device - - if denoise_function is not None: - model_output = denoise_function(x, t, **extra_args) - elif unconditional_conditioning is None or unconditional_guidance_scale == 1.: - model_output = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - if isinstance(c, dict): - assert isinstance(unconditional_conditioning, dict) - c_in = dict() - for k in c: - if isinstance(c[k], list): - c_in[k] = [torch.cat([ - unconditional_conditioning[k][i], - c[k][i]]) for i in range(len(c[k]))] - else: - c_in[k] = torch.cat([ - unconditional_conditioning[k], - c[k]]) - elif isinstance(c, list): - c_in = list() - assert isinstance(unconditional_conditioning, list) - for i in range(len(c)): - c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) - else: - c_in = torch.cat([unconditional_conditioning, c]) - model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) - - if self.parameterization == "v": - e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x - else: - e_t = model_output - - if score_corrector is not None: - assert self.parameterization == "eps", 'not implemented' - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 
1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - if self.parameterization != "v": - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - else: - pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output - - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - - if dynamic_threshold is not None: - raise NotImplementedError() - - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - @torch.no_grad() - def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, - unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): - num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] - - assert t_enc <= num_reference_steps - num_steps = t_enc - - if use_original_steps: - alphas_next = self.alphas_cumprod[:num_steps] - alphas = self.alphas_cumprod_prev[:num_steps] - else: - alphas_next = self.ddim_alphas[:num_steps] - alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) - - x_next = x0 - intermediates = [] - inter_steps = [] - for i in tqdm(range(num_steps), desc='Encoding Image'): - t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) - if unconditional_guidance_scale == 1.: - noise_pred = self.model.apply_model(x_next, t, c) - else: - assert unconditional_conditioning is not None - e_t_uncond, noise_pred = torch.chunk( - self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), - torch.cat((unconditional_conditioning, c))), 2) - noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) - - xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next - weighted_noise_pred = alphas_next[i].sqrt() * ( - (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred - x_next = xt_weighted + weighted_noise_pred - if return_intermediates and i % ( - num_steps // return_intermediates) == 0 and i < num_steps - 1: - intermediates.append(x_next) - inter_steps.append(i) - elif return_intermediates and i >= num_steps - 2: - intermediates.append(x_next) - inter_steps.append(i) - if callback: callback(i) - - out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} - if return_intermediates: - out.update({'intermediates': intermediates}) - return x_next, out - - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas - - if noise is None: - noise = torch.randn_like(x0) - if max_denoise: - noise_multiplier = 1.0 - else: - noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) - - return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise) - - 
@torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False, callback=None): - - timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps - timesteps = timesteps[:t_start] - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='Decoding image', total=total_steps) - x_dec = x_latent - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - if callback: callback(i) - return x_dec \ No newline at end of file diff --git a/comfy/ldm/models/diffusion/dpm_solver/__init__.py b/comfy/ldm/models/diffusion/dpm_solver/__init__.py deleted file mode 100644 index 7427f38c0..000000000 --- a/comfy/ldm/models/diffusion/dpm_solver/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py b/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py deleted file mode 100644 index da8d41f9c..000000000 --- a/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ /dev/null @@ -1,1163 +0,0 @@ -import torch -import torch.nn.functional as F -import math -from tqdm import tqdm - - -class NoiseScheduleVP: - def __init__( - self, - schedule='discrete', - betas=None, - alphas_cumprod=None, - continuous_beta_0=0.1, - continuous_beta_1=20., - ): - """Create a wrapper class for the forward SDE (VP type). - *** - Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. - We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. - *** - The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). - We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). - Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: - log_alpha_t = self.marginal_log_mean_coeff(t) - sigma_t = self.marginal_std(t) - lambda_t = self.marginal_lambda(t) - Moreover, as lambda(t) is an invertible function, we also support its inverse function: - t = self.inverse_lambda(lambda_t) - =============================================================== - We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). - 1. For discrete-time DPMs: - For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: - t_i = (i + 1) / N - e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. - We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. - Args: - betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) - alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) - Note that we always have alphas_cumprod = cumprod(betas). 
Therefore, we only need to set one of `betas` and `alphas_cumprod`. - **Important**: Please pay special attention for the args for `alphas_cumprod`: - The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that - q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). - Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have - alpha_{t_n} = \sqrt{\hat{alpha_n}}, - and - log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). - 2. For continuous-time DPMs: - We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise - schedule are the default settings in DDPM and improved-DDPM: - Args: - beta_min: A `float` number. The smallest beta for the linear schedule. - beta_max: A `float` number. The largest beta for the linear schedule. - cosine_s: A `float` number. The hyperparameter in the cosine schedule. - cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. - T: A `float` number. The ending time of the forward process. - =============================================================== - Args: - schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, - 'linear' or 'cosine' for continuous-time DPMs. - Returns: - A wrapper object of the forward SDE (VP type). - - =============================================================== - Example: - # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', betas=betas) - # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) - # For continuous-time DPMs (VPSDE), linear schedule: - >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) - """ - - if schedule not in ['discrete', 'linear', 'cosine']: - raise ValueError( - "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( - schedule)) - - self.schedule = schedule - if schedule == 'discrete': - if betas is not None: - log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) - else: - assert alphas_cumprod is not None - log_alphas = 0.5 * torch.log(alphas_cumprod) - self.total_N = len(log_alphas) - self.T = 1. - self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) - self.log_alpha_array = log_alphas.reshape((1, -1,)) - else: - self.total_N = 1000 - self.beta_0 = continuous_beta_0 - self.beta_1 = continuous_beta_1 - self.cosine_s = 0.008 - self.cosine_beta_max = 999. - self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) - self.schedule = schedule - if schedule == 'cosine': - # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. - # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. - self.T = 0.9946 - else: - self.T = 1. - - def marginal_log_mean_coeff(self, t): - """ - Compute log(alpha_t) of a given continuous-time label t in [0, T]. 
- """ - if self.schedule == 'discrete': - return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), - self.log_alpha_array.to(t.device)).reshape((-1)) - elif self.schedule == 'linear': - return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == 'cosine': - log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) - log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 - return log_alpha_t - - def marginal_alpha(self, t): - """ - Compute alpha_t of a given continuous-time label t in [0, T]. - """ - return torch.exp(self.marginal_log_mean_coeff(t)) - - def marginal_std(self, t): - """ - Compute sigma_t of a given continuous-time label t in [0, T]. - """ - return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) - - def marginal_lambda(self, t): - """ - Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. - """ - log_mean_coeff = self.marginal_log_mean_coeff(t) - log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) - return log_mean_coeff - log_std - - def inverse_lambda(self, lamb): - """ - Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. - """ - if self.schedule == 'linear': - tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0 ** 2 + tmp - return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) - elif self.schedule == 'discrete': - log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) - t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), - torch.flip(self.t_array.to(lamb.device), [1])) - return t.reshape((-1,)) - else: - log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - t = t_fn(log_alpha) - return t - - -def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1., - classifier_fn=None, - classifier_kwargs={}, -): - """Create a wrapper function for the noise prediction model. - DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to - firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. - We support four types of the diffusion model by setting `model_type`: - 1. "noise": noise prediction model. (Trained by predicting noise). - 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). - 3. "v": velocity prediction model. (Trained by predicting the velocity). - The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. - [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." - arXiv preprint arXiv:2202.00512 (2022). - [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." - arXiv preprint arXiv:2210.02303 (2022). - - 4. "score": marginal score function. (Trained by denoising score matching). 
- Note that the score function and the noise prediction model follows a simple relationship: - ``` - noise(x_t, t) = -sigma_t * score(x_t, t) - ``` - We support three types of guided sampling by DPMs by setting `guidance_type`: - 1. "uncond": unconditional sampling by DPMs. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - The input `classifier_fn` has the following format: - `` - classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) - `` - [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," - in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. - 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. - The input `model` has the following format: - `` - model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score - `` - And if cond == `unconditional_condition`, the model output is the unconditional DPM output. - [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." - arXiv preprint arXiv:2207.12598 (2022). - - The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) - or continuous-time labels (i.e. epsilon to T). - We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: - `` - def model_fn(x, t_continuous) -> noise: - t_input = get_model_input_time(t_continuous) - return noise_pred(model, x, t_input, **model_kwargs) - `` - where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. - =============================================================== - Args: - model: A diffusion model with the corresponding format described above. - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - model_type: A `str`. The parameterization type of the diffusion model. - "noise" or "x_start" or "v" or "score". - model_kwargs: A `dict`. A dict for the other inputs of the model function. - guidance_type: A `str`. The type of the guidance for sampling. - "uncond" or "classifier" or "classifier-free". - condition: A pytorch tensor. The condition for the guided sampling. - Only used for "classifier" or "classifier-free" guidance type. - unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. - Only used for "classifier-free" guidance type. - guidance_scale: A `float`. The scale for the guided sampling. - classifier_fn: A classifier function. Only used for the classifier guidance. - classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. - Returns: - A noise prediction model that accepts the noised data and the continuous time as the inputs. - """ - - def get_model_input_time(t_continuous): - """ - Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. - For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. - For continuous-time DPMs, we just use `t_continuous`. - """ - if noise_schedule.schedule == 'discrete': - return (t_continuous - 1. / noise_schedule.total_N) * 1000. 
- else: - return t_continuous - - def noise_pred_fn(x, t_continuous, cond=None): - if t_continuous.reshape((-1,)).shape[0] == 1: - t_continuous = t_continuous.expand((x.shape[0])) - t_input = get_model_input_time(t_continuous) - if cond is None: - output = model(x, t_input, **model_kwargs) - else: - output = model(x, t_input, cond, **model_kwargs) - if model_type == "noise": - return output - elif model_type == "x_start": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) - elif model_type == "v": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x - elif model_type == "score": - sigma_t = noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return -expand_dims(sigma_t, dims) * output - - def cond_grad_fn(x, t_input): - """ - Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). - """ - with torch.enable_grad(): - x_in = x.detach().requires_grad_(True) - log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) - return torch.autograd.grad(log_prob.sum(), x_in)[0] - - def model_fn(x, t_continuous): - """ - The noise predicition model function that is used for DPM-Solver. - """ - if t_continuous.reshape((-1,)).shape[0] == 1: - t_continuous = t_continuous.expand((x.shape[0])) - if guidance_type == "uncond": - return noise_pred_fn(x, t_continuous) - elif guidance_type == "classifier": - assert classifier_fn is not None - t_input = get_model_input_time(t_continuous) - cond_grad = cond_grad_fn(x, t_input) - sigma_t = noise_schedule.marginal_std(t_continuous) - noise = noise_pred_fn(x, t_continuous) - return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad - elif guidance_type == "classifier-free": - if guidance_scale == 1. or unconditional_condition is None: - return noise_pred_fn(x, t_continuous, cond=condition) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t_continuous] * 2) - if isinstance(condition, dict): - assert isinstance(unconditional_condition, dict) - c_in = dict() - for k in condition: - if isinstance(condition[k], list): - c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))] - else: - c_in[k] = torch.cat([unconditional_condition[k], condition[k]]) - else: - c_in = torch.cat([unconditional_condition, condition]) - noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) - return noise_uncond + guidance_scale * (noise - noise_uncond) - - assert model_type in ["noise", "x_start", "v"] - assert guidance_type in ["uncond", "classifier", "classifier-free"] - return model_fn - - -class DPM_Solver: - def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): - """Construct a DPM-Solver. - We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). - If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). - If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). - In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. - The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. 
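# Illustrative sketch (not part of the patch above): the parameterization conversions that
# noise_pred_fn performs, written as a standalone helper. Here `alpha_t` and `sigma_t` are
# assumed to come from a NoiseScheduleVP-style schedule and to be already broadcast to x's shape.
def to_noise_prediction(output, x, alpha_t, sigma_t, model_type="noise"):
    # Map a model output in any supported parameterization to an epsilon (noise) prediction.
    if model_type == "noise":
        return output
    elif model_type == "x_start":
        return (x - alpha_t * output) / sigma_t
    elif model_type == "v":
        return alpha_t * output + sigma_t * x
    elif model_type == "score":
        return -sigma_t * output
    raise ValueError("model_type must be 'noise', 'x_start', 'v' or 'score'")

def classifier_free_guidance(noise_cond, noise_uncond, guidance_scale):
    # The combination used by model_fn for guidance_type == "classifier-free".
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)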
- Args: - model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): - `` - def model_fn(x, t_continuous): - return noise - `` - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. - thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. - max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. - - [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. - """ - self.model = model_fn - self.noise_schedule = noise_schedule - self.predict_x0 = predict_x0 - self.thresholding = thresholding - self.max_val = max_val - - def noise_prediction_fn(self, x, t): - """ - Return the noise prediction model. - """ - return self.model(x, t) - - def data_prediction_fn(self, x, t): - """ - Return the data prediction model (with thresholding). - """ - noise = self.noise_prediction_fn(x, t) - dims = x.dim() - alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) - if self.thresholding: - p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. - s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) - x0 = torch.clamp(x0, -s, s) / s - return x0 - - def model_fn(self, x, t): - """ - Convert the model to the noise prediction model or the data prediction model. - """ - if self.predict_x0: - return self.data_prediction_fn(x, t) - else: - return self.noise_prediction_fn(x, t) - - def get_time_steps(self, skip_type, t_T, t_0, N, device): - """Compute the intermediate time steps for sampling. - Args: - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - N: A `int`. The total number of the spacing of the time steps. - device: A torch device. - Returns: - A pytorch tensor of the time steps, with the shape (N + 1,). - """ - if skip_type == 'logSNR': - lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) - lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) - logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) - return self.noise_schedule.inverse_lambda(logSNR_steps) - elif skip_type == 'time_uniform': - return torch.linspace(t_T, t_0, N + 1).to(device) - elif skip_type == 'time_quadratic': - t_order = 2 - t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. 
/ t_order), N + 1).pow(t_order).to(device) - return t - else: - raise ValueError( - "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) - - def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): - """ - Get the order of each step for sampling by the singlestep DPM-Solver. - We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". - Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: - - If order == 1: - We take `steps` of DPM-Solver-1 (i.e. DDIM). - - If order == 2: - - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of DPM-Solver-2. - - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If order == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. - ============================================ - Args: - order: A `int`. The max order for the solver (2 or 3). - steps: A `int`. The total number of function evaluations (NFE). - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - device: A torch device. - Returns: - orders: A list of the solver order of each step. - """ - if order == 3: - K = steps // 3 + 1 - if steps % 3 == 0: - orders = [3, ] * (K - 2) + [2, 1] - elif steps % 3 == 1: - orders = [3, ] * (K - 1) + [1] - else: - orders = [3, ] * (K - 1) + [2] - elif order == 2: - if steps % 2 == 0: - K = steps // 2 - orders = [2, ] * K - else: - K = steps // 2 + 1 - orders = [2, ] * (K - 1) + [1] - elif order == 1: - K = 1 - orders = [1, ] * steps - else: - raise ValueError("'order' must be '1' or '2' or '3'.") - if skip_type == 'logSNR': - # To reproduce the results in DPM-Solver paper - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) - else: - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ - torch.cumsum(torch.tensor([0, ] + orders)).to(device)] - return timesteps_outer, orders - - def denoise_to_zero_fn(self, x, s): - """ - Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. - """ - return self.data_prediction_fn(x, s) - - def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): - """ - DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. 
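# Illustrative sketch (not part of the patch above): the per-step order schedule described in
# get_orders_and_timesteps_for_singlestep_solver, reproduced as a small standalone helper so the
# "DPM-Solver-fast" splitting rule is easy to check by hand.
def singlestep_orders(steps, order):
    if order == 3:
        K = steps // 3 + 1
        if steps % 3 == 0:
            return [3] * (K - 2) + [2, 1]
        elif steps % 3 == 1:
            return [3] * (K - 1) + [1]
        else:
            return [3] * (K - 1) + [2]
    elif order == 2:
        K = steps // 2 if steps % 2 == 0 else steps // 2 + 1
        return [2] * K if steps % 2 == 0 else [2] * (K - 1) + [1]
    elif order == 1:
        return [1] * steps
    raise ValueError("'order' must be '1' or '2' or '3'.")

# e.g. singlestep_orders(10, 3) == [3, 3, 3, 1]; the orders always sum to `steps` (the NFE).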
- return_intermediate: A `bool`. If true, also return the model value at time `s`. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - if self.predict_x0: - phi_1 = torch.expm1(-h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - ) - if return_intermediate: - return x_t, {'model_s': model_s} - else: - return x_t - else: - phi_1 = torch.expm1(h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - ) - if return_intermediate: - return x_t, {'model_s': model_s} - else: - return x_t - - def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, - solver_type='dpm_solver'): - """ - Singlestep solver DPM-Solver-2 from time `s` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - r1: A `float`. The hyperparameter of the second-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 0.5 - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - s1 = ns.inverse_lambda(lambda_s1) - log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( - s1), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) - alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) - - if self.predict_x0: - phi_11 = torch.expm1(-r1 * h) - phi_1 = torch.expm1(-h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) 
/ h + 1.), dims) * ( - model_s1 - model_s) - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_1 = torch.expm1(h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) - ) - if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1} - else: - return x_t - - def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, - return_intermediate=False, solver_type='dpm_solver'): - """ - Singlestep solver DPM-Solver-3 from time `s` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - r1: A `float`. The hyperparameter of the third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). - If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 1. / 3. - if r2 is None: - r2 = 2. / 3. - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - lambda_s2 = lambda_s + r2 * h - s1 = ns.inverse_lambda(lambda_s1) - s2 = ns.inverse_lambda(lambda_s2) - log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( - s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( - s2), ns.marginal_std(t) - alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) - - if self.predict_x0: - phi_11 = torch.expm1(-r1 * h) - phi_12 = torch.expm1(-r2 * h) - phi_1 = torch.expm1(-h) - phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. - phi_2 = phi_1 / h + 1. 
- phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - expand_dims(sigma_s2 / sigma_s, dims) * x - - expand_dims(alpha_s2 * phi_12, dims) * model_s - + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) - ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + expand_dims(alpha_t * phi_2, dims) * D1 - - expand_dims(alpha_t * phi_3, dims) * D2 - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_12 = torch.expm1(r2 * h) - phi_1 = torch.expm1(h) - phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. - phi_2 = phi_1 / h - 1. - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - - expand_dims(sigma_s2 * phi_12, dims) * model_s - - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) - ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - expand_dims(sigma_t * phi_2, dims) * D1 - - expand_dims(sigma_t * phi_3, dims) * D2 - ) - - if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} - else: - return x_t - - def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): - """ - Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. 
- """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - ns = self.noise_schedule - dims = x.dim() - model_prev_1, model_prev_0 = model_prev_list - t_prev_1, t_prev_0 = t_prev_list - lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( - t_prev_0), ns.marginal_lambda(t) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0 = h_0 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - if self.predict_x0: - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 - ) - else: - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 - ) - return x_t - - def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): - """ - Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - dims = x.dim() - model_prev_2, model_prev_1, model_prev_0 = model_prev_list - t_prev_2, t_prev_1, t_prev_0 = t_prev_list - lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( - t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_1 = lambda_prev_1 - lambda_prev_2 - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0, r1 = h_0 / h, h_1 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) - D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) - D2 = expand_dims(1. 
/ (r0 + r1), dims) * (D1_0 - D1_1) - if self.predict_x0: - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 - ) - else: - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 - ) - return x_t - - def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, - r2=None): - """ - Singlestep DPM-Solver with the order `order` from time `s` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - r1: A `float`. The hyperparameter of the second-order or third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if order == 1: - return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) - elif order == 2: - return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1) - elif order == 3: - return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1, r2=r2) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): - """ - Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. 
- """ - if order == 1: - return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) - elif order == 2: - return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - elif order == 3: - return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, - solver_type='dpm_solver'): - """ - The adaptive step size solver based on singlestep DPM-Solver. - Args: - x: A pytorch tensor. The initial value at time `t_T`. - order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - h_init: A `float`. The initial step size (for logSNR). - atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. - rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. - theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. - t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the - current time and `t_0` is less than `t_err`. The default setting is 1e-5. - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_0: A pytorch tensor. The approximated solution at time `t_0`. - [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. - """ - ns = self.noise_schedule - s = t_T * torch.ones((x.shape[0],)).to(x) - lambda_s = ns.marginal_lambda(s) - lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) - h = h_init * torch.ones_like(s).to(x) - x_prev = x - nfe = 0 - if order == 2: - r1 = 0.5 - lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - solver_type=solver_type, - **kwargs) - elif order == 3: - r1, r2 = 1. / 3., 2. / 3. - lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - return_intermediate=True, - solver_type=solver_type) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, - solver_type=solver_type, - **kwargs) - else: - raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) - while torch.abs((s - t_0)).mean() > t_err: - t = ns.inverse_lambda(lambda_s + h) - x_lower, lower_noise_kwargs = lower_update(x, s, t) - x_higher = higher_update(x, s, t, **lower_noise_kwargs) - delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) - norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) - E = norm_fn((x_higher - x_lower) / delta).max() - if torch.all(E <= 1.): - x = x_higher - s = t - x_prev = x_lower - lambda_s = ns.marginal_lambda(s) - h = torch.min(theta * h * torch.float_power(E, -1. 
/ order).float(), lambda_0 - lambda_s)
-                nfe += order
-            print('adaptive solver nfe', nfe)
-            return x
-
-    def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
-               method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
-               atol=0.0078, rtol=0.05,
-               ):
-        """
-        Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
-        =====================================================
-        We support the following algorithms for both noise prediction model and data prediction model:
-        - 'singlestep':
-            Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
-            We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
-            The total number of function evaluations (NFE) == `steps`.
-            Given a fixed NFE == `steps`, the sampling procedure is:
-                - If `order` == 1:
-                    - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
-                - If `order` == 2:
-                    - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
-                    - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
-                    - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
-                - If `order` == 3:
-                    - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
-                    - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
-                    - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
-                    - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
-        - 'multistep':
-            Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
-            We initialize the first `order` values by lower order multistep solvers.
-            Given a fixed NFE == `steps`, the sampling procedure is:
-                Denote K = steps.
-                - If `order` == 1:
-                    - We use K steps of DPM-Solver-1 (i.e. DDIM).
-                - If `order` == 2:
-                    - We first use 1 step of DPM-Solver-1, then use (K - 1) steps of multistep DPM-Solver-2.
-                - If `order` == 3:
-                    - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
-        - 'singlestep_fixed':
-            Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
-            We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
-        - 'adaptive':
-            Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
-            We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
-            You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
-            (NFE) and the sample quality.
-                - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
-                - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
-        =====================================================
-        Some advice for choosing the algorithm:
-            - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
-                Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
-                e.g.
- >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, - skip_type='time_uniform', method='singlestep') - - For **guided sampling with large guidance scale** by DPMs: - Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. - e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, - skip_type='time_uniform', method='multistep') - We support three types of `skip_type`: - - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** - - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. - - 'time_quadratic': quadratic time for the time steps. - ===================================================== - Args: - x: A pytorch tensor. The initial value at time `t_start` - e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. - steps: A `int`. The total number of function evaluations (NFE). - t_start: A `float`. The starting time of the sampling. - If `T` is None, we use self.noise_schedule.T (default is 1.0). - t_end: A `float`. The ending time of the sampling. - If `t_end` is None, we use 1. / self.noise_schedule.total_N. - e.g. if total_N == 1000, we have `t_end` == 1e-3. - For discrete-time DPMs: - - We recommend `t_end` == 1. / self.noise_schedule.total_N. - For continuous-time DPMs: - - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. - order: A `int`. The order of DPM-Solver. - skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. - method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. - denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. - Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). - This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and - score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID - for diffusion models sampling by diffusion SDEs for low-resolutional images - (such as CIFAR-10). However, we observed that such trick does not matter for - high-resolutional images. As it needs an additional NFE, we do not recommend - it for high-resolutional images. - lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. - Only valid for `method=multistep` and `steps < 15`. We empirically find that - this trick is a key to stabilizing the sampling by DPM-Solver with very few steps - (especially for steps <= 10). So we recommend to set it to be `True`. - solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. - atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - Returns: - x_end: A pytorch tensor. The approximated solution at time `t_end`. - """ - t_0 = 1. 
/ self.noise_schedule.total_N if t_end is None else t_end - t_T = self.noise_schedule.T if t_start is None else t_start - device = x.device - if method == 'adaptive': - with torch.no_grad(): - x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, - solver_type=solver_type) - elif method == 'multistep': - assert steps >= order - timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - assert timesteps.shape[0] - 1 == steps - with torch.no_grad(): - vec_t = timesteps[0].expand((x.shape[0])) - model_prev_list = [self.model_fn(x, vec_t)] - t_prev_list = [vec_t] - # Init the first `order` values by lower order multistep DPM-Solver. - for init_order in tqdm(range(1, order), desc="DPM init order"): - vec_t = timesteps[init_order].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, - solver_type=solver_type) - model_prev_list.append(self.model_fn(x, vec_t)) - t_prev_list.append(vec_t) - # Compute the remaining values by `order`-th order multistep DPM-Solver. - for step in tqdm(range(order, steps + 1), desc="DPM multistep"): - vec_t = timesteps[step].expand(x.shape[0]) - if lower_order_final and steps < 15: - step_order = min(order, steps + 1 - step) - else: - step_order = order - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, - solver_type=solver_type) - for i in range(order - 1): - t_prev_list[i] = t_prev_list[i + 1] - model_prev_list[i] = model_prev_list[i + 1] - t_prev_list[-1] = vec_t - # We do not need to evaluate the final model value. - if step < steps: - model_prev_list[-1] = self.model_fn(x, vec_t) - elif method in ['singlestep', 'singlestep_fixed']: - if method == 'singlestep': - timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, - skip_type=skip_type, - t_T=t_T, t_0=t_0, - device=device) - elif method == 'singlestep_fixed': - K = steps // order - orders = [order, ] * K - timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) - for i, order in enumerate(orders): - t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] - timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), - N=order, device=device) - lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) - vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) - h = lambda_inner[-1] - lambda_inner[0] - r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h - r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h - x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) - if denoise_to_zero: - x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) - return x - - -############################################################# -# other utility functions -############################################################# - -def interpolate_fn(x, xp, yp): - """ - A piecewise linear function y = f(x), using xp and yp as keypoints. - We implement f(x) in a differentiable way (i.e. applicable for autograd). - The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) - Args: - x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). 
- xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. - yp: PyTorch tensor with shape [C, K]. - Returns: - The function values f(x), with shape [N, C]. - """ - N, K = x.shape[0], xp.shape[1] - all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) - sorted_all_x, x_indices = torch.sort(all_x, dim=2) - x_idx = torch.argmin(x_indices, dim=2) - cand_start_idx = x_idx - 1 - start_idx = torch.where( - torch.eq(x_idx, 0), - torch.tensor(1, device=x.device), - torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, - ), - ) - end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) - start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) - end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) - start_idx2 = torch.where( - torch.eq(x_idx, 0), - torch.tensor(0, device=x.device), - torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, - ), - ) - y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) - start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) - end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) - cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) - return cand - - -def expand_dims(v, dims): - """ - Expand the tensor `v` to the dim `dims`. - Args: - `v`: a PyTorch tensor with shape [N]. - `dim`: a `int`. - Returns: - a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. - """ - return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/comfy/ldm/models/diffusion/dpm_solver/sampler.py b/comfy/ldm/models/diffusion/dpm_solver/sampler.py deleted file mode 100644 index e4d0d0a38..000000000 --- a/comfy/ldm/models/diffusion/dpm_solver/sampler.py +++ /dev/null @@ -1,96 +0,0 @@ -"""SAMPLING ONLY.""" -import torch - -from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver - -MODEL_TYPES = { - "eps": "noise", - "v": "v" -} - - -class DPMSolverSampler(object): - def __init__(self, model, device=torch.device("cuda"), **kwargs): - super().__init__() - self.model = model - self.device = device - to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) - self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != self.device: - attr = attr.to(self.device) - setattr(self, name, attr) - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
- **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] - if isinstance(ctmp, torch.Tensor): - cbs = ctmp.shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - elif isinstance(conditioning, list): - for ctmp in conditioning: - if ctmp.shape[0] != batch_size: - print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}") - else: - if isinstance(conditioning, torch.Tensor): - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - - print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') - - device = self.model.betas.device - if x_T is None: - img = torch.randn(size, device=device) - else: - img = x_T - - ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) - - model_fn = model_wrapper( - lambda x, t, c: self.model.apply_model(x, t, c), - ns, - model_type=MODEL_TYPES[self.model.parameterization], - guidance_type="classifier-free", - condition=conditioning, - unconditional_condition=unconditional_conditioning, - guidance_scale=unconditional_guidance_scale, - ) - - dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, - lower_order_final=True) - - return x.to(device), None diff --git a/comfy/ldm/models/diffusion/plms.py b/comfy/ldm/models/diffusion/plms.py deleted file mode 100644 index 9d31b3994..000000000 --- a/comfy/ldm/models/diffusion/plms.py +++ /dev/null @@ -1,245 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm -from functools import partial - -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like -from ldm.models.diffusion.sampling_util import norm_thresholding - - -class PLMSSampler(object): - def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - self.device = device - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != self.device: - attr = attr.to(self.device) - setattr(self, name, attr) - - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - if ddim_eta != 0: - raise ValueError('ddim_eta must be 0 for PLMS') - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) - alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. 
- alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') - - samples, intermediates = self.plms_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ) - return samples, intermediates - - @torch.no_grad() - def plms_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = 
self.ddim_timesteps[:subset_end] - - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running PLMS Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) - old_eps = [] - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. - mask) * img - - outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, t_next=ts_next, - dynamic_threshold=dynamic_threshold) - img, pred_x0, e_t = outs - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, - dynamic_threshold=None): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - return e_t - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = 
self.model.first_stage_model.quantize(pred_x0) - if dynamic_threshold is not None: - pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - # 3nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - elif len(old_eps) >= 3: - # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 - - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - - return x_prev, pred_x0, e_t diff --git a/comfy/ldm/models/diffusion/sampling_util.py b/comfy/ldm/models/diffusion/sampling_util.py deleted file mode 100644 index 7eff02be6..000000000 --- a/comfy/ldm/models/diffusion/sampling_util.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch -import numpy as np - - -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions. - From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') - return x[(...,) + (None,) * dims_to_append] - - -def norm_thresholding(x0, value): - s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) - return x0 * (value / s) - - -def spatial_norm_thresholding(x0, value): - # b c h w - s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) - return x0 * (value / s) \ No newline at end of file diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index dcf467489..f68452382 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -5,8 +5,10 @@ import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat from typing import Optional, Any +from functools import partial -from .diffusionmodules.util import checkpoint + +from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management @@ -160,32 +162,19 @@ def attention_sub_quad(query, key, value, heads, mask=None): mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True) - chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD - kv_chunk_size_min = None + kv_chunk_size = None + query_chunk_size = None - #not sure at all about the math here - #TODO: tweak this - if mem_free_total > 8192 * 1024 * 1024 * 1.3: - query_chunk_size_x = 1024 * 4 - elif mem_free_total > 4096 * 1024 * 1024 * 1.3: - query_chunk_size_x = 1024 * 2 - else: - query_chunk_size_x = 1024 - kv_chunk_size_min_x = None - kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * 
query_chunk_size_x)) * 2.0) // 1024) * 1024 - if kv_chunk_size_x < 1024: - kv_chunk_size_x = None + for x in [4096, 2048, 1024, 512, 256]: + count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0) + if count >= k_tokens: + kv_chunk_size = k_tokens + query_chunk_size = x + break - if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: - # the big matmul fits into our memory limit; do everything in 1 chunk, - # i.e. send it down the unchunked fast-path - query_chunk_size = q_tokens - kv_chunk_size = k_tokens - else: - query_chunk_size = query_chunk_size_x - kv_chunk_size = kv_chunk_size_x - kv_chunk_size_min = kv_chunk_size_min_x + if query_chunk_size is None: + query_chunk_size = 512 hidden_states = efficient_dot_product_attention( query, @@ -222,9 +211,14 @@ def attention_split(q, k, v, heads, mask=None): mem_free_total = model_management.get_free_memory(q.device) + if _ATTN_PRECISION =="fp32": + element_size = 4 + else: + element_size = q.element_size() + gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size + modifier = 3 mem_required = tensor_size * modifier steps = 1 @@ -252,10 +246,10 @@ def attention_split(q, k, v, heads, mask=None): s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale else: s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale - first_op_done = True s2 = s1.softmax(dim=-1).to(v.dtype) del s1 + first_op_done = True r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 @@ -284,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None): ) return r1 +BROKEN_XFORMERS = False +try: + x_vers = xformers.__version__ + #I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error) + BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23") +except: + pass + def attention_xformers(q, k, v, heads, mask=None): b, _, dim_head = q.shape dim_head //= heads + if BROKEN_XFORMERS: + if b * heads > 65535: + return attention_pytorch(q, k, v, heads, mask) q, k, v = map( lambda t: t.unsqueeze(3) @@ -378,53 +383,72 @@ class CrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False, dtype=None, device=None, operations=comfy.ops): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, + disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops): super().__init__() + + self.ff_in = ff_in or inner_dim is not None + if inner_dim is None: + inner_dim = dim + + self.is_res = inner_dim == dim + + if self.ff_in: + self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device) + self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) + self.disable_self_attn = disable_self_attn - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn - 
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device) - self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device) - self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device) + self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) + + if disable_temporal_crossattention: + if switch_temporal_ca_to_sa: + raise ValueError + else: + self.attn2 = None + else: + context_dim_attn2 = None + if not switch_temporal_ca_to_sa: + context_dim_attn2 = context_dim + + self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, + heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none + self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + + self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) self.checkpoint = checkpoint self.n_heads = n_heads self.d_head = d_head + self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa def forward(self, x, context=None, transformer_options={}): return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) def _forward(self, x, context=None, transformer_options={}): extra_options = {} - block = None - block_index = 0 - if "current_index" in transformer_options: - extra_options["transformer_index"] = transformer_options["current_index"] - if "block_index" in transformer_options: - block_index = transformer_options["block_index"] - extra_options["block_index"] = block_index - if "original_shape" in transformer_options: - extra_options["original_shape"] = transformer_options["original_shape"] - if "block" in transformer_options: - block = transformer_options["block"] - extra_options["block"] = block - if "cond_or_uncond" in transformer_options: - extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"] - if "patches" in transformer_options: - transformer_patches = transformer_options["patches"] - else: - transformer_patches = {} + block = transformer_options.get("block", None) + block_index = transformer_options.get("block_index", 0) + transformer_patches = {} + transformer_patches_replace = {} + + for k in transformer_options: + if k == "patches": + transformer_patches = transformer_options[k] + elif k == "patches_replace": + transformer_patches_replace = transformer_options[k] + else: + extra_options[k] = transformer_options[k] extra_options["n_heads"] = self.n_heads extra_options["dim_head"] = self.d_head - if "patches_replace" in transformer_options: - transformer_patches_replace = transformer_options["patches_replace"] - else: - transformer_patches_replace = {} + if self.ff_in: + x_skip = x + x = self.ff_in(self.norm_in(x)) + if self.is_res: + x += x_skip n = self.norm1(x) if self.disable_self_attn: @@ -473,31 +497,34 @@ class BasicTransformerBlock(nn.Module): for p in patch: x = p(x, extra_options) - n = self.norm2(x) - - context_attn2 = context - value_attn2 = None - if "attn2_patch" in transformer_patches: - patch = transformer_patches["attn2_patch"] - value_attn2 = context_attn2 - for p in patch: - n, context_attn2, value_attn2 = 
p(n, context_attn2, value_attn2, extra_options) - - attn2_replace_patch = transformer_patches_replace.get("attn2", {}) - block_attn2 = transformer_block - if block_attn2 not in attn2_replace_patch: - block_attn2 = block - - if block_attn2 in attn2_replace_patch: - if value_attn2 is None: + if self.attn2 is not None: + n = self.norm2(x) + if self.switch_temporal_ca_to_sa: + context_attn2 = n + else: + context_attn2 = context + value_attn2 = None + if "attn2_patch" in transformer_patches: + patch = transformer_patches["attn2_patch"] value_attn2 = context_attn2 - n = self.attn2.to_q(n) - context_attn2 = self.attn2.to_k(context_attn2) - value_attn2 = self.attn2.to_v(value_attn2) - n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) - n = self.attn2.to_out(n) - else: - n = self.attn2(n, context=context_attn2, value=value_attn2) + for p in patch: + n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) + + attn2_replace_patch = transformer_patches_replace.get("attn2", {}) + block_attn2 = transformer_block + if block_attn2 not in attn2_replace_patch: + block_attn2 = block + + if block_attn2 in attn2_replace_patch: + if value_attn2 is None: + value_attn2 = context_attn2 + n = self.attn2.to_q(n) + context_attn2 = self.attn2.to_k(context_attn2) + value_attn2 = self.attn2.to_v(value_attn2) + n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) + n = self.attn2.to_out(n) + else: + n = self.attn2(n, context=context_attn2, value=value_attn2) if "attn2_output_patch" in transformer_patches: patch = transformer_patches["attn2_output_patch"] @@ -505,7 +532,12 @@ class BasicTransformerBlock(nn.Module): n = p(n, extra_options) x += n - x = self.ff(self.norm3(x)) + x + if self.is_res: + x_skip = x + x = self.ff(self.norm3(x)) + if self.is_res: + x += x_skip + return x @@ -573,3 +605,164 @@ class SpatialTransformer(nn.Module): x = self.proj_out(x) return x + x_in + +class SpatialVideoTransformer(SpatialTransformer): + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + time_context_dim=None, + ff_in=False, + checkpoint=False, + time_depth=1, + disable_self_attn=False, + disable_temporal_crossattention=False, + max_time_embed_period: int = 10000, + dtype=None, device=None, operations=comfy.ops + ): + super().__init__( + in_channels, + n_heads, + d_head, + depth=depth, + dropout=dropout, + use_checkpoint=checkpoint, + context_dim=context_dim, + use_linear=use_linear, + disable_self_attn=disable_self_attn, + dtype=dtype, device=device, operations=operations + ) + self.time_depth = time_depth + self.depth = depth + self.max_time_embed_period = max_time_embed_period + + time_mix_d_head = d_head + n_time_mix_heads = n_heads + + time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) + + inner_dim = n_heads * d_head + if use_spatial_context: + time_context_dim = context_dim + + self.time_stack = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_time_mix_heads, + time_mix_d_head, + dropout=dropout, + context_dim=time_context_dim, + # timesteps=timesteps, + checkpoint=checkpoint, + ff_in=ff_in, + inner_dim=time_mix_inner_dim, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + dtype=dtype, device=device, operations=operations + ) + for _ in range(self.depth) + ] + ) + + assert 
len(self.time_stack) == len(self.transformer_blocks) + + self.use_spatial_context = use_spatial_context + self.in_channels = in_channels + + time_embed_dim = self.in_channels * 4 + self.time_pos_embed = nn.Sequential( + operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device), + ) + + self.time_mixer = AlphaBlender( + alpha=merge_factor, merge_strategy=merge_strategy + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, + transformer_options={} + ) -> torch.Tensor: + _, _, h, w = x.shape + x_in = x + spatial_context = None + if exists(context): + spatial_context = context + + if self.use_spatial_context: + assert ( + context.ndim == 3 + ), f"n dims of spatial context should be 3 but are {context.ndim}" + + if time_context is None: + time_context = context + time_context_first_timestep = time_context[::timesteps] + time_context = repeat( + time_context_first_timestep, "b ... -> (b n) ...", n=h * w + ) + elif time_context is not None and not self.use_spatial_context: + time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w) + if time_context.ndim == 2: + time_context = rearrange(time_context, "b c -> b 1 c") + + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + if self.use_linear: + x = self.proj_in(x) + + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype) + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + for it_, (block, mix_block) in enumerate( + zip(self.transformer_blocks, self.time_stack) + ): + transformer_options["block_index"] = it_ + x = block( + x, + context=spatial_context, + transformer_options=transformer_options, + ) + + x_mix = x + x_mix = x_mix + emb + + B, S, C = x_mix.shape + x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps) + x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options + x_mix = rearrange( + x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps + ) + + x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + if not self.use_linear: + x = self.proj_out(x) + out = x + x_in + return out + + diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index bf58a4045..48264892c 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -5,6 +5,8 @@ import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F +from einops import rearrange +from functools import partial from .util import ( checkpoint, @@ -12,8 +14,9 @@ from .util import ( zero_module, normalization, timestep_embedding, + AlphaBlender, ) -from ..attention import SpatialTransformer +from ..attention import SpatialTransformer, SpatialVideoTransformer, default from comfy.ldm.util import exists import comfy.ops @@ -28,6 +31,26 @@ class 
TimestepBlock(nn.Module): Apply the module to `x` given `emb` timestep embeddings. """ +#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index" +def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): + for layer in ts: + if isinstance(layer, VideoResBlock): + x = layer(x, emb, num_video_frames, image_only_indicator) + elif isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialVideoTransformer): + x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options) + if "transformer_index" in transformer_options: + transformer_options["transformer_index"] += 1 + elif isinstance(layer, SpatialTransformer): + x = layer(x, context, transformer_options) + if "transformer_index" in transformer_options: + transformer_options["transformer_index"] += 1 + elif isinstance(layer, Upsample): + x = layer(x, output_shape=output_shape) + else: + x = layer(x) + return x class TimestepEmbedSequential(nn.Sequential, TimestepBlock): """ @@ -35,31 +58,8 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): support it as an extra input. """ - def forward(self, x, emb, context=None, transformer_options={}, output_shape=None): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer): - x = layer(x, context, transformer_options) - elif isinstance(layer, Upsample): - x = layer(x, output_shape=output_shape) - else: - x = layer(x) - return x - -#This is needed because accelerate makes a copy of transformer_options which breaks "current_index" -def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None): - for layer in ts: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer): - x = layer(x, context, transformer_options) - transformer_options["current_index"] += 1 - elif isinstance(layer, Upsample): - x = layer(x, output_shape=output_shape) - else: - x = layer(x) - return x + def forward(self, *args, **kwargs): + return forward_timestep_embed(self, *args, **kwargs) class Upsample(nn.Module): """ @@ -154,6 +154,9 @@ class ResBlock(TimestepBlock): use_checkpoint=False, up=False, down=False, + kernel_size=3, + exchange_temb_dims=False, + skip_t_emb=False, dtype=None, device=None, operations=comfy.ops @@ -166,11 +169,17 @@ class ResBlock(TimestepBlock): self.use_conv = use_conv self.use_checkpoint = use_checkpoint self.use_scale_shift_norm = use_scale_shift_norm + self.exchange_temb_dims = exchange_temb_dims + + if isinstance(kernel_size, list): + padding = [k // 2 for k in kernel_size] + else: + padding = kernel_size // 2 self.in_layers = nn.Sequential( nn.GroupNorm(32, channels, dtype=dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device), + operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device), ) self.updown = up or down @@ -184,19 +193,24 @@ class ResBlock(TimestepBlock): else: self.h_upd = self.x_upd = nn.Identity() - self.emb_layers = nn.Sequential( - nn.SiLU(), - operations.Linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device - ), - ) + self.skip_t_emb = skip_t_emb + if self.skip_t_emb: + self.emb_layers = None + self.exchange_temb_dims = 
False + else: + self.emb_layers = nn.Sequential( + nn.SiLU(), + operations.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device + ), + ) self.out_layers = nn.Sequential( nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), nn.SiLU(), nn.Dropout(p=dropout), zero_module( - operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device) + operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) ), ) @@ -204,7 +218,7 @@ class ResBlock(TimestepBlock): self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = operations.conv_nd( - dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device + dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device ) else: self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device) @@ -230,19 +244,110 @@ class ResBlock(TimestepBlock): h = in_conv(h) else: h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] + + emb_out = None + if not self.skip_t_emb: + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] if self.use_scale_shift_norm: out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift + h = out_norm(h) + if emb_out is not None: + scale, shift = th.chunk(emb_out, 2, dim=1) + h *= (1 + scale) + h += shift h = out_rest(h) else: - h = h + emb_out + if emb_out is not None: + if self.exchange_temb_dims: + emb_out = rearrange(emb_out, "b t c ... -> b c t ...") + h = h + emb_out h = self.out_layers(h) return self.skip_connection(x) + h + +class VideoResBlock(ResBlock): + def __init__( + self, + channels: int, + emb_channels: int, + dropout: float, + video_kernel_size=3, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + out_channels=None, + use_conv: bool = False, + use_scale_shift_norm: bool = False, + dims: int = 2, + use_checkpoint: bool = False, + up: bool = False, + down: bool = False, + dtype=None, + device=None, + operations=comfy.ops + ): + super().__init__( + channels, + emb_channels, + dropout, + out_channels=out_channels, + use_conv=use_conv, + use_scale_shift_norm=use_scale_shift_norm, + dims=dims, + use_checkpoint=use_checkpoint, + up=up, + down=down, + dtype=dtype, + device=device, + operations=operations + ) + + self.time_stack = ResBlock( + default(out_channels, channels), + emb_channels, + dropout=dropout, + dims=3, + out_channels=default(out_channels, channels), + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=use_checkpoint, + exchange_temb_dims=True, + dtype=dtype, + device=device, + operations=operations + ) + self.time_mixer = AlphaBlender( + alpha=merge_factor, + merge_strategy=merge_strategy, + rearrange_pattern="b t -> b 1 t 1 1", + ) + + def forward( + self, + x: th.Tensor, + emb: th.Tensor, + num_video_frames: int, + image_only_indicator = None, + ) -> th.Tensor: + x = super().forward(x, emb) + + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + + x = self.time_stack( + x, rearrange(emb, "(b t) ... 
-> b t ...", t=num_video_frames) + ) + x = self.time_mixer( + x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator + ) + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + class Timestep(nn.Module): def __init__(self, dim): super().__init__() @@ -251,6 +356,15 @@ class Timestep(nn.Module): def forward(self, t): return timestep_embedding(t, self.dim) +def apply_control(h, control, name): + if control is not None and name in control and len(control[name]) > 0: + ctrl = control[name].pop() + if ctrl is not None: + try: + h += ctrl + except: + print("warning control could not be applied", h.shape, ctrl.shape) + return h class UNetModel(nn.Module): """ @@ -259,10 +373,6 @@ class UNetModel(nn.Module): :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param conv_resample: if True, use learned convolutions for upsampling and @@ -289,7 +399,6 @@ class UNetModel(nn.Module): model_channels, out_channels, num_res_blocks, - attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, @@ -314,6 +423,17 @@ class UNetModel(nn.Module): use_linear_in_transformer=False, adm_in_channels=None, transformer_depth_middle=None, + transformer_depth_output=None, + use_temporal_resblock=False, + use_temporal_attention=False, + time_context_dim=None, + extra_ff_mix_layer=False, + use_spatial_context=False, + merge_strategy=None, + merge_factor=0.0, + video_kernel_size=None, + disable_temporal_crossattention=False, + max_ddpm_temb_period=10000, device=None, operations=comfy.ops, ): @@ -341,10 +461,7 @@ class UNetModel(nn.Module): self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels - if isinstance(transformer_depth, int): - transformer_depth = len(channel_mult) * [transformer_depth] - if transformer_depth_middle is None: - transformer_depth_middle = transformer_depth[-1] + if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: @@ -352,18 +469,16 @@ class UNetModel(nn.Module): raise ValueError("provide num_res_blocks either as an int (globally constant) or " "as a list/tuple (per-level) with the same length as channel_mult") self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") - self.attention_resolutions = attention_resolutions + transformer_depth = transformer_depth[:] + transformer_depth_output = transformer_depth_output[:] + self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample @@ -373,8 +488,12 @@ class UNetModel(nn.Module): self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample + self.use_temporal_resblocks = use_temporal_resblock self.predict_codebook_ids = n_embed is not None + self.default_num_video_frames = None + self.default_image_only_indicator = None + time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device), @@ -411,13 +530,104 @@ class UNetModel(nn.Module): input_block_chans = [model_channels] ch = model_channels ds = 1 + + def get_attention_layer( + ch, + num_heads, + dim_head, + depth=1, + context_dim=None, + use_checkpoint=False, + disable_self_attn=False, + ): + if use_temporal_attention: + return SpatialVideoTransformer( + ch, + num_heads, + dim_head, + depth=depth, + context_dim=context_dim, + time_context_dim=time_context_dim, + dropout=dropout, + ff_in=extra_ff_mix_layer, + use_spatial_context=use_spatial_context, + merge_strategy=merge_strategy, + merge_factor=merge_factor, + checkpoint=use_checkpoint, + use_linear=use_linear_in_transformer, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + max_time_embed_period=max_ddpm_temb_period, + dtype=self.dtype, device=device, operations=operations + ) + else: + return SpatialTransformer( + ch, num_heads, dim_head, depth=depth, context_dim=context_dim, + disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + ) + + def get_resblock( + merge_factor, + merge_strategy, + video_kernel_size, + ch, + time_embed_dim, + dropout, + out_channels, + dims, + use_checkpoint, + use_scale_shift_norm, + down=False, + up=False, + dtype=None, + device=None, + operations=comfy.ops + ): + if self.use_temporal_resblocks: + return VideoResBlock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + dtype=dtype, + device=device, + operations=operations + ) + else: + return ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_channels, + use_checkpoint=use_checkpoint, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + dtype=dtype, + device=device, + operations=operations + ) + for level, mult in enumerate(channel_mult): for nr in range(self.num_res_blocks[level]): layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, @@ -428,7 +638,8 @@ class UNetModel(nn.Module): ) ] ch = 
mult * model_channels - if ds in attention_resolutions: + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: if num_head_channels == -1: dim_head = ch // num_heads else: @@ -443,11 +654,9 @@ class UNetModel(nn.Module): disabled_sa = False if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: - layers.append(SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations - ) + layers.append(get_attention_layer( + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, + disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch @@ -456,10 +665,13 @@ class UNetModel(nn.Module): out_ch = ch self.input_blocks.append( TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, @@ -488,35 +700,43 @@ class UNetModel(nn.Module): if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, + mid_block = [ + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_channels=None, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, dtype=self.dtype, device=device, operations=operations - ), - SpatialTransformer( # always uses a self-attn + )] + if transformer_depth_middle >= 0: + mid_block += [get_attention_layer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint ), - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_channels=None, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, dtype=self.dtype, device=device, operations=operations - ), - ) + )] + self.middle_block = TimestepEmbedSequential(*mid_block) self._feature_size += ch self.output_blocks = nn.ModuleList([]) @@ -524,10 +744,13 @@ class UNetModel(nn.Module): for i in range(self.num_res_blocks[level] + 1): ich = input_block_chans.pop() layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch + ich, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=model_channels * mult, dims=dims, use_checkpoint=use_checkpoint, @@ -538,7 +761,8 @@ class UNetModel(nn.Module): ) ] ch = model_channels * mult - if ds in attention_resolutions: + num_transformers = transformer_depth_output.pop() + if num_transformers > 0: if 
num_head_channels == -1: dim_head = ch // num_heads else: @@ -554,19 +778,21 @@ class UNetModel(nn.Module): if not exists(num_attention_blocks) or i < num_attention_blocks[level]: layers.append( - SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + get_attention_layer( + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, + disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint ) ) if level and i == self.num_res_blocks[level]: out_ch = ch layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, @@ -605,9 +831,13 @@ class UNetModel(nn.Module): :return: an [N x C x ...] Tensor of outputs. """ transformer_options["original_shape"] = list(x.shape) - transformer_options["current_index"] = 0 + transformer_options["transformer_index"] = 0 transformer_patches = transformer_options.get("patches", {}) + num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) + image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) + time_context = kwargs.get("time_context", None) + assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" @@ -622,26 +852,28 @@ class UNetModel(nn.Module): h = x.type(self.dtype) for id, module in enumerate(self.input_blocks): transformer_options["block"] = ("input", id) - h = forward_timestep_embed(module, h, emb, context, transformer_options) - if control is not None and 'input' in control and len(control['input']) > 0: - ctrl = control['input'].pop() - if ctrl is not None: - h += ctrl + h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) + h = apply_control(h, control, 'input') + if "input_block_patch" in transformer_patches: + patch = transformer_patches["input_block_patch"] + for p in patch: + h = p(h, transformer_options) + hs.append(h) + if "input_block_patch_after_skip" in transformer_patches: + patch = transformer_patches["input_block_patch_after_skip"] + for p in patch: + h = p(h, transformer_options) + transformer_options["block"] = ("middle", 0) - h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) - if control is not None and 'middle' in control and len(control['middle']) > 0: - ctrl = control['middle'].pop() - if ctrl is not None: - h += ctrl + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) + h = apply_control(h, control, 'middle') + for id, module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) hsp = hs.pop() - if control is not None and 'output' in control and len(control['output']) > 0: - ctrl = control['output'].pop() - if ctrl is not None: - hsp += ctrl + hsp = apply_control(hsp, control, 'output') if "output_block_patch" in transformer_patches: patch = transformer_patches["output_block_patch"] @@ -654,7 +886,7 @@ class 
UNetModel(nn.Module): output_shape = hs[-1].shape else: output_shape = None - h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) + h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index d890c8044..704bbe574 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -13,11 +13,78 @@ import math import torch import torch.nn as nn import numpy as np -from einops import repeat +from einops import repeat, rearrange from comfy.ldm.util import instantiate_from_config import comfy.ops +class AlphaBlender(nn.Module): + strategies = ["learned", "fixed", "learned_with_images"] + + def __init__( + self, + alpha: float, + merge_strategy: str = "learned_with_images", + rearrange_pattern: str = "b t -> (b t) 1 1", + ): + super().__init__() + self.merge_strategy = merge_strategy + self.rearrange_pattern = rearrange_pattern + + assert ( + merge_strategy in self.strategies + ), f"merge_strategy needs to be in {self.strategies}" + + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif ( + self.merge_strategy == "learned" + or self.merge_strategy == "learned_with_images" + ): + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t) + if self.merge_strategy == "fixed": + # make shape compatible + # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) + alpha = self.mix_factor + elif self.merge_strategy == "learned": + alpha = torch.sigmoid(self.mix_factor) + # make shape compatible + # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) + elif self.merge_strategy == "learned_with_images": + assert image_only_indicator is not None, "need image_only_indicator ..." + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor), "... -> ... 
1"), + ) + alpha = rearrange(alpha, self.rearrange_pattern) + # make shape compatible + # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) + else: + raise NotImplementedError() + return alpha + + def forward( + self, + x_spatial, + x_temporal, + image_only_indicator=None, + ) -> torch.Tensor: + alpha = self.get_alpha(image_only_indicator) + x = ( + alpha.to(x_spatial.dtype) * x_spatial + + (1.0 - alpha).to(x_spatial.dtype) * x_temporal + ) + return x + + def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if schedule == "linear": betas = ( @@ -170,8 +237,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): if not repeat_only: half = dim // 2 freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device) + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half + ) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 4d42059b5..8e8e8054d 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -83,7 +83,8 @@ def _summarize_chunk( ) max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() - torch.exp(attn_weights - max_score, out=attn_weights) + attn_weights -= max_score + torch.exp(attn_weights, out=attn_weights) exp_weights = attn_weights.to(value.dtype) exp_values = torch.bmm(exp_weights, value) max_score = max_score.squeeze(-1) diff --git a/comfy/ldm/modules/temporal_ae.py b/comfy/ldm/modules/temporal_ae.py new file mode 100644 index 000000000..11ae049f3 --- /dev/null +++ b/comfy/ldm/modules/temporal_ae.py @@ -0,0 +1,244 @@ +import functools +from typing import Callable, Iterable, Union + +import torch +from einops import rearrange, repeat + +import comfy.ops + +from .diffusionmodules.model import ( + AttnBlock, + Decoder, + ResnetBlock, +) +from .diffusionmodules.openaimodel import ResBlock, timestep_embedding +from .attention import BasicTransformerBlock + +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +class VideoResBlock(ResnetBlock): + def __init__( + self, + out_channels, + *args, + dropout=0.0, + video_kernel_size=3, + alpha=0.0, + merge_strategy="learned", + **kwargs, + ): + super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) + if video_kernel_size is None: + video_kernel_size = [3, 1, 1] + self.time_stack = ResBlock( + channels=out_channels, + emb_channels=0, + dropout=dropout, + dims=3, + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=False, + skip_t_emb=True, + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, bs): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError() + + 
def forward(self, x, temb, skip_video=False, timesteps=None): + b, c, h, w = x.shape + if timesteps is None: + timesteps = b + + x = super().forward(x, temb) + + if not skip_video: + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = self.time_stack(x, temb) + + alpha = self.get_alpha(bs=b // timesteps) + x = alpha * x + (1.0 - alpha) * x_mix + + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + +class AE3DConv(torch.nn.Conv2d): + def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): + super().__init__(in_channels, out_channels, *args, **kwargs) + if isinstance(video_kernel_size, Iterable): + padding = [int(k // 2) for k in video_kernel_size] + else: + padding = int(video_kernel_size // 2) + + self.time_mix_conv = torch.nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=video_kernel_size, + padding=padding, + ) + + def forward(self, input, timesteps=None, skip_video=False): + if timesteps is None: + timesteps = input.shape[0] + x = super().forward(input) + if skip_video: + return x + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + x = self.time_mix_conv(x) + return rearrange(x, "b c t h w -> (b t) c h w") + + +class AttnVideoBlock(AttnBlock): + def __init__( + self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" + ): + super().__init__(in_channels) + # no context, single headed, as in base class + self.time_mix_block = BasicTransformerBlock( + dim=in_channels, + n_heads=1, + d_head=in_channels, + checkpoint=False, + ff_in=True, + ) + + time_embed_dim = self.in_channels * 4 + self.video_time_embed = torch.nn.Sequential( + comfy.ops.Linear(self.in_channels, time_embed_dim), + torch.nn.SiLU(), + comfy.ops.Linear(time_embed_dim, self.in_channels), + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def forward(self, x, timesteps=None, skip_time_block=False): + if skip_time_block: + return super().forward(x) + + if timesteps is None: + timesteps = x.shape[0] + + x_in = x + x = self.attention(x) + h, w = x.shape[2:] + x = rearrange(x, "b c h w -> b (h w) c") + + x_mix = x + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = self.video_time_embed(t_emb) # b, n_channels + emb = emb[:, None, :] + x_mix = x_mix + emb + + alpha = self.get_alpha() + x_mix = self.time_mix_block(x_mix, timesteps=timesteps) + x = alpha * x + (1.0 - alpha) * x_mix # alpha merge + + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + return x_in + x + + def get_alpha( + self, + ): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") + + + +def make_time_attn( + in_channels, + attn_type="vanilla", + attn_kwargs=None, + alpha: float = 0, + merge_strategy: str = "learned", +): + return partialclass( + AttnVideoBlock, in_channels, 
alpha=alpha, merge_strategy=merge_strategy + ) + + +class Conv2DWrapper(torch.nn.Conv2d): + def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: + return super().forward(input) + + +class VideoDecoder(Decoder): + available_time_modes = ["all", "conv-only", "attn-only"] + + def __init__( + self, + *args, + video_kernel_size: Union[int, list] = 3, + alpha: float = 0.0, + merge_strategy: str = "learned", + time_mode: str = "conv-only", + **kwargs, + ): + self.video_kernel_size = video_kernel_size + self.alpha = alpha + self.merge_strategy = merge_strategy + self.time_mode = time_mode + assert ( + self.time_mode in self.available_time_modes + ), f"time_mode parameter has to be in {self.available_time_modes}" + + if self.time_mode != "attn-only": + kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) + if self.time_mode not in ["conv-only", "only-last-conv"]: + kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy) + if self.time_mode not in ["attn-only", "only-last-conv"]: + kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy) + + super().__init__(*args, **kwargs) + + def get_last_layer(self, skip_time_mix=False, **kwargs): + if self.time_mode == "attn-only": + raise NotImplementedError("TODO") + else: + return ( + self.conv_out.time_mix_conv.weight + if not skip_time_mix + else self.conv_out.weight + ) diff --git a/comfy/lora.py b/comfy/lora.py index 3009a1c9e..29c59d893 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -131,6 +131,18 @@ def load_lora(lora, to_load): loaded_keys.add(b_norm_name) patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,) + diff_name = "{}.diff".format(x) + diff_weight = lora.get(diff_name, None) + if diff_weight is not None: + patch_dict[to_load[x]] = (diff_weight,) + loaded_keys.add(diff_name) + + diff_bias_name = "{}.diff_b".format(x) + diff_bias = lora.get(diff_bias_name, None) + if diff_bias is not None: + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,) + loaded_keys.add(diff_bias_name) + for x in lora.keys(): if x not in loaded_keys: print("lora key not loaded", x) @@ -141,9 +153,9 @@ def model_lora_keys_clip(model, key_map={}): text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False - for b in range(32): + for b in range(32): #TODO: clean up for c in LORA_CLIP_MAP: - k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) key_map[lora_key] = k @@ -154,6 +166,8 @@ def model_lora_keys_clip(model, key_map={}): k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: + lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base key_map[lora_key] = k clip_l_present = True diff --git a/comfy/model_base.py b/comfy/model_base.py index cda6765e4..786c9cf47 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,16 +1,37 @@ import torch from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation -from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from 
comfy.ldm.modules.diffusionmodules.openaimodel import Timestep import comfy.model_management -import numpy as np +import comfy.conds from enum import Enum from . import utils class ModelType(Enum): EPS = 1 V_PREDICTION = 2 + V_PREDICTION_EDM = 3 + + +from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM + + +def model_sampling(model_config, model_type): + s = ModelSamplingDiscrete + + if model_type == ModelType.EPS: + c = EPS + elif model_type == ModelType.V_PREDICTION: + c = V_PREDICTION + elif model_type == ModelType.V_PREDICTION_EDM: + c = V_PREDICTION + s = ModelSamplingContinuousEDM + + class ModelSampling(s, c): + pass + + return ModelSampling(model_config) + class BaseModel(torch.nn.Module): def __init__(self, model_config, model_type=ModelType.EPS, device=None): @@ -19,10 +40,12 @@ class BaseModel(torch.nn.Module): unet_config = model_config.unet_config self.latent_format = model_config.latent_format self.model_config = model_config - self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + if not unet_config.get("disable_unet_model_creation", False): self.diffusion_model = UNetModel(**unet_config, device=device) self.model_type = model_type + self.model_sampling = model_sampling(model_config, model_type) + self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 @@ -30,38 +53,25 @@ class BaseModel(torch.nn.Module): print("model_type", model_type.name) print("adm", self.adm_channels) - def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - if given_betas is not None: - betas = given_betas - else: - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) - alphas = 1. 
- betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - - self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) - self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) - self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) - - def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}): + def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + sigma = t + xc = self.model_sampling.calculate_input(sigma, x) if c_concat is not None: - xc = torch.cat([x] + [c_concat], dim=1) - else: - xc = x + xc = torch.cat([xc] + [c_concat], dim=1) + context = c_crossattn dtype = self.get_dtype() xc = xc.to(dtype) - t = t.to(dtype) + t = self.model_sampling.timestep(t).float() context = context.to(dtype) - if c_adm is not None: - c_adm = c_adm.to(dtype) - return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float() + extra_conds = {} + for o in kwargs: + extra = kwargs[o] + if hasattr(extra, "to"): + extra = extra.to(dtype) + extra_conds[o] = extra + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() + return self.model_sampling.calculate_denoised(sigma, model_output, x) def get_dtype(self): return self.diffusion_model.dtype @@ -72,7 +82,8 @@ class BaseModel(torch.nn.Module): def encode_adm(self, **kwargs): return None - def cond_concat(self, **kwargs): + def extra_conds(self, **kwargs): + out = {} if self.inpaint_model: concat_keys = ("mask", "masked_image") cond_concat = [] @@ -101,8 +112,12 @@ class BaseModel(torch.nn.Module): cond_concat.append(torch.ones_like(noise)[:,:1]) elif ck == "masked_image": cond_concat.append(blank_inpaint_image_like(noise)) - return cond_concat - return None + data = torch.cat(cond_concat, dim=1) + out['c_concat'] = comfy.conds.CONDNoiseShape(data) + adm = self.encode_adm(**kwargs) + if adm is not None: + out['y'] = comfy.conds.CONDRegular(adm) + return out def load_model_weights(self, sd, unet_prefix=""): to_load = {} @@ -111,6 +126,7 @@ class BaseModel(torch.nn.Module): if k.startswith(unet_prefix): to_load[k[len(unet_prefix):]] = sd.pop(k) + to_load = self.model_config.process_unet_state_dict(to_load) m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: print("unet missing:", m) @@ -147,6 +163,17 @@ class BaseModel(torch.nn.Module): def set_inpaint(self): self.inpaint_model = True + def memory_required(self, input_shape): + if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): + #TODO: this needs to be tweaked + area = input_shape[0] * input_shape[2] * input_shape[3] + return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024) + else: + #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. 
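# Rough sense of scale for this estimate (approximate, derived from the two formulas in this method):
# a single 64x64 latent, input_shape = [1, 4, 64, 64], gives area = 1 * 64 * 64 = 4096, so the
# fallback branch below reserves about ((4096 * 0.6) / 0.9 + 1024) MiB, i.e. roughly 3.7 GiB,
# while the xformers/flash-attention branch above reserves about area * dtype_size / 50 MiB,
# around 160 MiB at fp16.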
+ area = input_shape[0] * input_shape[2] * input_shape[3] + return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) + + def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): adm_inputs = [] weights = [] @@ -241,3 +268,48 @@ class SDXL(BaseModel): out.append(self.embedder(torch.Tensor([target_width]))) flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) + +class SVD_img2vid(BaseModel): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): + super().__init__(model_config, model_type, device=device) + self.embedder = Timestep(256) + + def encode_adm(self, **kwargs): + fps_id = kwargs.get("fps", 6) - 1 + motion_bucket_id = kwargs.get("motion_bucket_id", 127) + augmentation = kwargs.get("augmentation_level", 0) + + out = [] + out.append(self.embedder(torch.Tensor([fps_id]))) + out.append(self.embedder(torch.Tensor([motion_bucket_id]))) + out.append(self.embedder(torch.Tensor([augmentation]))) + + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) + return flat + + def extra_conds(self, **kwargs): + out = {} + adm = self.encode_adm(**kwargs) + if adm is not None: + out['y'] = comfy.conds.CONDRegular(adm) + + latent_image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + device = kwargs["device"] + + if latent_image is None: + latent_image = torch.zeros_like(noise) + + if latent_image.shape[1:] != noise.shape[1:]: + latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0]) + + out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image) + + if "time_conditioning" in kwargs: + out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"]) + + out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device)) + out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0ff2e7fb5..c682c3e1a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -14,6 +14,20 @@ def count_blocks(state_dict_keys, prefix_string): count += 1 return count +def calculate_transformer_depth(prefix, state_dict_keys, state_dict): + context_dim = None + use_linear_in_transformer = False + + transformer_prefix = prefix + "1.transformer_blocks." 
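# Reading the per-block keys here: counting <prefix>1.transformer_blocks.N. entries gives this
# block's transformer depth, attn2.to_k.weight provides context_dim, a 2-dimensional proj_in.weight
# means use_linear_in_transformer, and the presence of time_stack / time_mix_blocks keys marks a
# temporal (video) model. For example, an SDXL input block at the deepest level would yield
# (10, 2048, True, False), matching the transformer_depth / context_dim values in the configs
# further below.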
+ transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys))) + if len(transformer_keys) > 0: + last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') + context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] + use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 + time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict + return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack + return None + def detect_unet_config(state_dict, key_prefix, dtype): state_dict_keys = list(state_dict.keys()) @@ -40,72 +54,95 @@ def detect_unet_config(state_dict, key_prefix, dtype): channel_mult = [] attention_resolutions = [] transformer_depth = [] + transformer_depth_output = [] context_dim = None use_linear_in_transformer = False + video_model = False current_res = 1 count = 0 last_res_blocks = 0 - last_transformer_depth = 0 last_channel_mult = 0 - while True: + input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.') + for count in range(input_block_count): prefix = '{}input_blocks.{}.'.format(key_prefix, count) + prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1) + block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys))) if len(block_keys) == 0: break + block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys))) + if "{}0.op.weight".format(prefix) in block_keys: #new layer - if last_transformer_depth > 0: - attention_resolutions.append(current_res) - transformer_depth.append(last_transformer_depth) num_res_blocks.append(last_res_blocks) channel_mult.append(last_channel_mult) current_res *= 2 last_res_blocks = 0 - last_transformer_depth = 0 last_channel_mult = 0 + out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict) + if out is not None: + transformer_depth_output.append(out[0]) + else: + transformer_depth_output.append(0) else: res_block_prefix = "{}0.in_layers.0.weight".format(prefix) if res_block_prefix in block_keys: last_res_blocks += 1 last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels - transformer_prefix = prefix + "1.transformer_blocks." 
- transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys))) - if len(transformer_keys) > 0: - last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') - if context_dim is None: - context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] - use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 + out = calculate_transformer_depth(prefix, state_dict_keys, state_dict) + if out is not None: + transformer_depth.append(out[0]) + if context_dim is None: + context_dim = out[1] + use_linear_in_transformer = out[2] + video_model = out[3] + else: + transformer_depth.append(0) + + res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output) + if res_block_prefix in block_keys_output: + out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict) + if out is not None: + transformer_depth_output.append(out[0]) + else: + transformer_depth_output.append(0) - count += 1 - if last_transformer_depth > 0: - attention_resolutions.append(current_res) - transformer_depth.append(last_transformer_depth) num_res_blocks.append(last_res_blocks) channel_mult.append(last_channel_mult) - transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') - - if len(set(num_res_blocks)) == 1: - num_res_blocks = num_res_blocks[0] - - if len(set(transformer_depth)) == 1: - transformer_depth = transformer_depth[0] + if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys: + transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') + else: + transformer_depth_middle = -1 unet_config["in_channels"] = in_channels unet_config["model_channels"] = model_channels unet_config["num_res_blocks"] = num_res_blocks - unet_config["attention_resolutions"] = attention_resolutions unet_config["transformer_depth"] = transformer_depth + unet_config["transformer_depth_output"] = transformer_depth_output unet_config["channel_mult"] = channel_mult unet_config["transformer_depth_middle"] = transformer_depth_middle unet_config['use_linear_in_transformer'] = use_linear_in_transformer unet_config["context_dim"] = context_dim + + if video_model: + unet_config["extra_ff_mix_layer"] = True + unet_config["use_spatial_context"] = True + unet_config["merge_strategy"] = "learned_with_images" + unet_config["merge_factor"] = 0.0 + unet_config["video_kernel_size"] = [3, 1, 1] + unet_config["use_temporal_resblock"] = True + unet_config["use_temporal_attention"] = True + else: + unet_config["use_temporal_resblock"] = False + unet_config["use_temporal_attention"] = False + return unet_config def model_config_from_unet_config(unet_config): @@ -124,19 +161,65 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma else: return model_config +def convert_config(unet_config): + new_config = unet_config.copy() + num_res_blocks = new_config.get("num_res_blocks", None) + channel_mult = new_config.get("channel_mult", None) + + if isinstance(num_res_blocks, int): + num_res_blocks = len(channel_mult) * [num_res_blocks] + + if "attention_resolutions" in new_config: + attention_resolutions = new_config.pop("attention_resolutions") + transformer_depth = new_config.get("transformer_depth", None) + transformer_depth_middle = new_config.get("transformer_depth_middle", None) + + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * 
[transformer_depth] + if transformer_depth_middle is None: + transformer_depth_middle = transformer_depth[-1] + t_in = [] + t_out = [] + s = 1 + for i in range(len(num_res_blocks)): + res = num_res_blocks[i] + d = 0 + if s in attention_resolutions: + d = transformer_depth[i] + + t_in += [d] * res + t_out += [d] * (res + 1) + s *= 2 + transformer_depth = t_in + transformer_depth_output = t_out + new_config["transformer_depth"] = t_in + new_config["transformer_depth_output"] = t_out + new_config["transformer_depth_middle"] = transformer_depth_middle + + new_config["num_res_blocks"] = num_res_blocks + return new_config + + def unet_config_from_diffusers_unet(state_dict, dtype): match = {} - attention_resolutions = [] + transformer_depth = [] attn_res = 1 - for i in range(5): - k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i) - if k in state_dict: - match["context_dim"] = state_dict[k].shape[1] - attention_resolutions.append(attn_res) - attn_res *= 2 + down_blocks = count_blocks(state_dict, "down_blocks.{}") + for i in range(down_blocks): + attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}') + for ab in range(attn_blocks): + transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}') + transformer_depth.append(transformer_count) + if transformer_count > 0: + match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1] - match["attention_resolutions"] = attention_resolutions + attn_res *= 2 + if attn_blocks == 0: + transformer_depth.append(0) + transformer_depth.append(0) + + match["transformer_depth"] = transformer_depth match["model_channels"] = state_dict["conv_in.weight"].shape[0] match["in_channels"] = state_dict["conv_in.weight"].shape[1] @@ -148,50 +231,65 @@ def unet_config_from_diffusers_unet(state_dict, dtype): SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], - 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384, - 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64} + 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4, + 'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0], + 'use_temporal_attention': False, 
'use_temporal_resblock': False} SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, - 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64} + 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], + 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, + 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64} + 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, + 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} + 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, + 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} - SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, - 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8} + SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, + 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], + 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8, + 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} 
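# Illustrative sketch only: a hypothetical helper (expand_transformer_depth is not in the
# patch) mirroring the convert_config() logic added in comfy/model_detection.py, to show how
# the legacy attention_resolutions-style settings expand into the per-block transformer_depth
# and transformer_depth_output lists used by the configs above.
def expand_transformer_depth(num_res_blocks, channel_mult, attention_resolutions, transformer_depth):
    if isinstance(num_res_blocks, int):
        num_res_blocks = len(channel_mult) * [num_res_blocks]
    if isinstance(transformer_depth, int):
        transformer_depth = len(channel_mult) * [transformer_depth]
    t_in, t_out = [], []
    s = 1
    for i, res in enumerate(num_res_blocks):
        d = transformer_depth[i] if s in attention_resolutions else 0
        t_in += [d] * res         # one entry per res block on the input side
        t_out += [d] * (res + 1)  # output side has one extra block per level
        s *= 2
    return t_in, t_out

# SD1.5-style input: num_res_blocks=2, channel_mult=[1, 2, 4, 4],
# attention_resolutions=[1, 2, 4], transformer_depth=[1, 1, 1, 0]
# -> ([1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]),
# matching the SD15/SD21 transformer_depth and transformer_depth_output entries above.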
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1, + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4], - 'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1} + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0, + 'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], - 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], + 'use_temporal_attention': False, 'use_temporal_resblock': False} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint] + SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, 
SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B] for unet_config in supported_models: matches = True @@ -200,7 +298,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype): matches = False break if matches: - return unet_config + return convert_config(unet_config) return None def model_config_from_diffusers_unet(state_dict, dtype): diff --git a/comfy/model_management.py b/comfy/model_management.py index 53582fc73..d4acd8950 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -133,6 +133,10 @@ else: import xformers import xformers.ops XFORMERS_IS_AVAILABLE = True + try: + XFORMERS_IS_AVAILABLE = xformers._has_cpp_library + except: + pass try: XFORMERS_VERSION = xformers.version.__version__ print("xformers version:", XFORMERS_VERSION) @@ -478,6 +482,21 @@ def text_encoder_device(): else: return torch.device("cpu") +def text_encoder_dtype(device=None): + if args.fp8_e4m3fn_text_enc: + return torch.float8_e4m3fn + elif args.fp8_e5m2_text_enc: + return torch.float8_e5m2 + elif args.fp16_text_enc: + return torch.float16 + elif args.fp32_text_enc: + return torch.float32 + + if should_use_fp16(device, prioritize_performance=False): + return torch.float16 + else: + return torch.float32 + def vae_device(): return get_torch_device() @@ -579,27 +598,6 @@ def get_free_memory(dev=None, torch_free_too=False): else: return mem_free_total -def batch_area_memory(area): - if xformers_enabled() or pytorch_attention_flash_attention(): - #TODO: these formulas are copied from maximum_batch_area below - return (area / 20) * (1024 * 1024) - else: - return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) - -def maximum_batch_area(): - global vram_state - if vram_state == VRAMState.NO_VRAM: - return 0 - - memory_free = get_free_memory() / (1024 * 1024) - if xformers_enabled() or pytorch_attention_flash_attention(): - #TODO: this needs to be tweaked - area = 20 * memory_free - else: - #TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future - area = ((memory_free - 1024) * 0.9) / (0.6) - return int(max(area, 0)) - def cpu_mode(): global cpu_state return cpu_state == CPUState.CPU diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 50b725b86..a3cffc3be 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -6,11 +6,13 @@ import comfy.utils import comfy.model_management class ModelPatcher: - def __init__(self, model, load_device, offload_device, size=0, current_device=None): + def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): self.size = size self.model = model self.patches = {} self.backup = {} + self.object_patches = {} + self.object_patches_backup = {} self.model_options = {"transformer_options":{}} self.model_size() self.load_device = load_device @@ -20,6 +22,8 @@ class ModelPatcher: else: self.current_device = current_device + self.weight_inplace_update = weight_inplace_update + def model_size(self): if self.size > 0: return self.size @@ -33,11 +37,12 @@ class ModelPatcher: return size def clone(self): - n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device) + n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] + n.object_patches = self.object_patches.copy() n.model_options = copy.deepcopy(self.model_options) 
n.model_keys = self.model_keys return n @@ -47,6 +52,9 @@ class ModelPatcher: return True return False + def memory_required(self, input_shape): + return self.model.memory_required(input_shape=input_shape) + def set_model_sampler_cfg_function(self, sampler_cfg_function): if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way @@ -88,9 +96,18 @@ class ModelPatcher: def set_model_attn2_output_patch(self, patch): self.set_model_patch(patch, "attn2_output_patch") + def set_model_input_block_patch(self, patch): + self.set_model_patch(patch, "input_block_patch") + + def set_model_input_block_patch_after_skip(self, patch): + self.set_model_patch(patch, "input_block_patch_after_skip") + def set_model_output_block_patch(self, patch): self.set_model_patch(patch, "output_block_patch") + def add_object_patch(self, name, obj): + self.object_patches[name] = obj + def model_patches_to(self, device): to = self.model_options["transformer_options"] if "patches" in to: @@ -107,10 +124,10 @@ class ModelPatcher: for k in patch_list: if hasattr(patch_list[k], "to"): patch_list[k] = patch_list[k].to(device) - if "unet_wrapper_function" in self.model_options: - wrap_func = self.model_options["unet_wrapper_function"] + if "model_function_wrapper" in self.model_options: + wrap_func = self.model_options["model_function_wrapper"] if hasattr(wrap_func, "to"): - self.model_options["unet_wrapper_function"] = wrap_func.to(device) + self.model_options["model_function_wrapper"] = wrap_func.to(device) def model_dtype(self): if hasattr(self.model, "get_dtype"): @@ -128,6 +145,7 @@ class ModelPatcher: return list(p) def get_key_patches(self, filter_prefix=None): + comfy.model_management.unload_model_clones(self) model_sd = self.model_state_dict() p = {} for k in model_sd: @@ -150,6 +168,12 @@ class ModelPatcher: return sd def patch_model(self, device_to=None): + for k in self.object_patches: + old = getattr(self.model, k) + if k not in self.object_patches_backup: + self.object_patches_backup[k] = old + setattr(self.model, k, self.object_patches[k]) + model_sd = self.model_state_dict() for key in self.patches: if key not in model_sd: @@ -158,15 +182,20 @@ class ModelPatcher: weight = model_sd[key] + inplace_update = self.weight_inplace_update + if key not in self.backup: - self.backup[key] = weight.to(self.offload_device) + self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) if device_to is not None: temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) else: temp_weight = weight.to(torch.float32, copy=True) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - comfy.utils.set_attr(self.model, key, out_weight) + if inplace_update: + comfy.utils.copy_to_param(self.model, key, out_weight) + else: + comfy.utils.set_attr(self.model, key, out_weight) del temp_weight if device_to is not None: @@ -282,11 +311,21 @@ class ModelPatcher: def unpatch_model(self, device_to=None): keys = list(self.backup.keys()) - for k in keys: - comfy.utils.set_attr(self.model, k, self.backup[k]) + if self.weight_inplace_update: + for k in keys: + comfy.utils.copy_to_param(self.model, k, self.backup[k]) + else: + for k in keys: + comfy.utils.set_attr(self.model, k, self.backup[k]) self.backup = {} if device_to is not None: self.model.to(device_to) self.current_device = device_to + + keys = 
list(self.object_patches_backup.keys()) + for k in keys: + setattr(self.model, k, self.object_patches_backup[k]) + + self.object_patches_backup = {} diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py new file mode 100644 index 000000000..69c8b1f01 --- /dev/null +++ b/comfy/model_sampling.py @@ -0,0 +1,129 @@ +import torch +import numpy as np +from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule +import math + +class EPS: + def calculate_input(self, sigma, noise): + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + +class V_PREDICTION(EPS): + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + + +class ModelSamplingDiscrete(torch.nn.Module): + def __init__(self, model_config=None): + super().__init__() + beta_schedule = "linear" + if model_config is not None: + beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule) + self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + self.sigma_data = 1.0 + + def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + alphas = 1. 
- betas + alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) + # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + + # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) + # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) + # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) + + sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + self.set_sigmas(sigmas) + + def set_sigmas(self, sigmas): + self.register_buffer('sigmas', sigmas) + self.register_buffer('log_sigmas', sigmas.log()) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + log_sigma = sigma.log() + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device) + + def sigma(self, timestep): + t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1)) + low_idx = t.floor().long() + high_idx = t.ceil().long() + w = t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp().to(timestep.device) + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 999999999.9 + if percent >= 1.0: + return 0.0 + percent = 1.0 - percent + return self.sigma(torch.tensor(percent * 999.0)).item() + + +class ModelSamplingContinuousEDM(torch.nn.Module): + def __init__(self, model_config=None): + super().__init__() + self.sigma_data = 1.0 + + if model_config is not None: + sampling_settings = model_config.sampling_settings + else: + sampling_settings = {} + + sigma_min = sampling_settings.get("sigma_min", 0.002) + sigma_max = sampling_settings.get("sigma_max", 120.0) + self.set_sigma_range(sigma_min, sigma_max) + + def set_sigma_range(self, sigma_min, sigma_max): + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp() + + self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers + self.register_buffer('log_sigmas', sigmas.log()) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return 0.25 * sigma.log() + + def sigma(self, timestep): + return (timestep / 0.25).exp() + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 999999999.9 + if percent >= 1.0: + return 0.0 + percent = 1.0 - percent + + log_sigma_min = math.log(self.sigma_min) + return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min) diff --git a/comfy/ops.py b/comfy/ops.py index 610d54584..0bfb698aa 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1,29 +1,23 @@ import torch from contextlib import contextmanager -class Linear(torch.nn.Module): - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) - if bias: - self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs)) - else: - self.register_parameter('bias', None) - - def 
forward(self, input): - return torch.nn.functional.linear(input, self.weight, self.bias) +class Linear(torch.nn.Linear): + def reset_parameters(self): + return None class Conv2d(torch.nn.Conv2d): def reset_parameters(self): return None +class Conv3d(torch.nn.Conv3d): + def reset_parameters(self): + return None + def conv_nd(dims, *args, **kwargs): if dims == 2: return Conv2d(*args, **kwargs) + elif dims == 3: + return Conv3d(*args, **kwargs) else: raise ValueError(f"unsupported dimensions: {dims}") diff --git a/comfy/sample.py b/comfy/sample.py index e6a69973d..034db97ee 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -1,6 +1,7 @@ import torch import comfy.model_management import comfy.samplers +import comfy.conds import comfy.utils import math import numpy as np @@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device): noise_mask = noise_mask.to(device) return noise_mask -def broadcast_cond(cond, batch, device): - """broadcasts conditioning to the batch size""" - copy = [] - for p in cond: - t = comfy.utils.repeat_to_batch_size(p[0], batch) - t = t.to(device) - copy += [[t] + p[1:]] - return copy - def get_models_from_cond(cond, model_type): models = [] for c in cond: - if model_type in c[1]: - models += [c[1][model_type]] + if model_type in c: + models += [c[model_type]] return models +def convert_cond(cond): + out = [] + for c in cond: + temp = c[1].copy() + model_conds = temp.get("model_conds", {}) + if c[0] is not None: + model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) + temp["model_conds"] = model_conds + out.append(temp) + return out + def get_additional_models(positive, negative, dtype): """loads additional models in positive and negative conditioning""" control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) @@ -72,18 +75,18 @@ def cleanup_additional_models(models): def prepare_sampling(model, noise_shape, positive, negative, noise_mask): device = model.load_device + positive = convert_cond(positive) + negative = convert_cond(negative) if noise_mask is not None: noise_mask = prepare_mask(noise_mask, noise_shape, device) real_model = None models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) - comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory) + comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) real_model = model.model - positive_copy = broadcast_cond(positive, noise_shape[0], device) - negative_copy = broadcast_cond(negative, noise_shape[0], device) - return real_model, positive_copy, negative_copy, noise_mask, models + return real_model, positive, negative, noise_mask, models def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): diff --git a/comfy/samplers.py b/comfy/samplers.py index 0b38fbd1e..1d012a514 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,48 +1,42 @@ from .k_diffusion import sampling as k_diffusion_sampling -from .k_diffusion import external as k_diffusion_external from .extra_samplers import uni_pc import torch +import enum from comfy import model_management -from .ldm.models.diffusion.ddim import DDIMSampler -from 
.ldm.modules.diffusionmodules.util import make_ddim_timesteps import math from comfy import model_base import comfy.utils +import comfy.conds -def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) - return abs(a*b) // math.gcd(a, b) #The main sampling function shared by all the samplers -#Returns predicted noise -def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - def get_area_and_mult(cond, x_in, timestep_in): +#Returns denoised +def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): + def get_area_and_mult(conds, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 - if 'timestep_start' in cond[1]: - timestep_start = cond[1]['timestep_start'] + + if 'timestep_start' in conds: + timestep_start = conds['timestep_start'] if timestep_in[0] > timestep_start: return None - if 'timestep_end' in cond[1]: - timestep_end = cond[1]['timestep_end'] + if 'timestep_end' in conds: + timestep_end = conds['timestep_end'] if timestep_in[0] < timestep_end: return None - if 'area' in cond[1]: - area = cond[1]['area'] - if 'strength' in cond[1]: - strength = cond[1]['strength'] - - adm_cond = None - if 'adm_encoded' in cond[1]: - adm_cond = cond[1]['adm_encoded'] + if 'area' in conds: + area = conds['area'] + if 'strength' in conds: + strength = conds['strength'] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if 'mask' in cond[1]: + if 'mask' in conds: # Scale the mask to the size of the input # The mask should have been resized as we began the sampling process mask_strength = 1.0 - if "mask_strength" in cond[1]: - mask_strength = cond[1]["mask_strength"] - mask = cond[1]['mask'] + if "mask_strength" in conds: + mask_strength = conds["mask_strength"] + mask = conds['mask'] assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength @@ -51,7 +45,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod mask = torch.ones_like(input_x) mult = mask * strength - if 'mask' not in cond[1]: + if 'mask' not in conds: rr = 8 if area[2] != 0: for t in range(rr): @@ -67,27 +61,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) conditionning = {} - conditionning['c_crossattn'] = cond[0] - - if 'concat' in cond[1]: - cond_concat_in = cond[1]['concat'] - if cond_concat_in is not None and len(cond_concat_in) > 0: - cropped = [] - for x in cond_concat_in: - cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - cropped.append(cr) - conditionning['c_concat'] = torch.cat(cropped, dim=1) - - if adm_cond is not None: - conditionning['c_adm'] = adm_cond + model_conds = conds["model_conds"] + for c in model_conds: + conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) control = None - if 'control' in cond[1]: - control = cond[1]['control'] + if 'control' in conds: + control = conds['control'] patches = None - if 'gligen' in cond[1]: - gligen = cond[1]['gligen'] + if 'gligen' in conds: + gligen = conds['gligen'] patches = {} gligen_type = gligen[0] gligen_model = gligen[1] @@ -105,22 +89,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod return True if c1.keys() != c2.keys(): return False - if 'c_crossattn' in c1: - s1 = 
c1['c_crossattn'].shape - s2 = c2['c_crossattn'].shape - if s1 != s2: - if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen - return False - - mult_min = lcm(s1[1], s2[1]) - diff = mult_min // min(s1[1], s2[1]) - if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much - return False - if 'c_concat' in c1: - if c1['c_concat'].shape != c2['c_concat'].shape: - return False - if 'c_adm' in c1: - if c1['c_adm'].shape != c2['c_adm'].shape: + for k in c1: + if not c1[k].can_concat(c2[k]): return False return True @@ -149,39 +119,27 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod c_concat = [] c_adm = [] crossattn_max_len = 0 - for x in c_list: - if 'c_crossattn' in x: - c = x['c_crossattn'] - if crossattn_max_len == 0: - crossattn_max_len = c.shape[1] - else: - crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) - c_crossattn.append(c) - if 'c_concat' in x: - c_concat.append(x['c_concat']) - if 'c_adm' in x: - c_adm.append(x['c_adm']) - out = {} - c_crossattn_out = [] - for c in c_crossattn: - if c.shape[1] < crossattn_max_len: - c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result - c_crossattn_out.append(c) - if len(c_crossattn_out) > 0: - out['c_crossattn'] = torch.cat(c_crossattn_out) - if len(c_concat) > 0: - out['c_concat'] = torch.cat(c_concat) - if len(c_adm) > 0: - out['c_adm'] = torch.cat(c_adm) + temp = {} + for x in c_list: + for k in x: + cur = temp.get(k, []) + cur.append(x[k]) + temp[k] = cur + + out = {} + for k in temp: + conds = temp[k] + out[k] = conds[0].concat(conds[1:]) + return out - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options): + def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): out_cond = torch.zeros_like(x_in) - out_count = torch.ones_like(x_in)/100000.0 + out_count = torch.ones_like(x_in) * 1e-37 out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.ones_like(x_in)/100000.0 + out_uncond_count = torch.ones_like(x_in) * 1e-37 COND = 0 UNCOND = 1 @@ -212,9 +170,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod to_batch_temp.reverse() to_batch = to_batch_temp[:1] + free_memory = model_management.get_free_memory(x_in.device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] - if (len(batch_amount) * first_shape[0] * first_shape[2] * first_shape[3] < max_total_area): + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + if model.memory_required(input_shape) < free_memory: to_batch = batch_amount break @@ -260,12 +220,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod transformer_options["patches"] = patches transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["sigmas"] = timestep + c['transformer_options'] = transformer_options if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) else: - output = model_function(input_x, timestep_, **c).chunk(batch_chunks) + output = model.apply_model(input_x, timestep_, 
**c).chunk(batch_chunks) del input_x for o in range(batch_chunks): @@ -281,39 +243,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod del out_count out_uncond /= out_uncond_count del out_uncond_count - return out_cond, out_uncond - max_total_area = model_management.maximum_batch_area() if math.isclose(cond_scale, 1.0): uncond = None - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options) + cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) if "sampler_cfg_function" in model_options: - args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep} - return model_options["sampler_cfg_function"](args) + args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} + return x - model_options["sampler_cfg_function"](args) else: return uncond + (cond - uncond) * cond_scale - -class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser): - def __init__(self, model, quantize=False, device='cpu'): - super().__init__(model, model.alphas_cumprod, quantize=quantize) - - def get_v(self, x, t, cond, **kwargs): - return self.inner_model.apply_model(x, t, cond, **kwargs) - - class CFGNoisePredictor(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - self.alphas_cumprod = model.alphas_cumprod def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None): - out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) + out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) return out - + def forward(self, *args, **kwargs): + return self.apply_model(*args, **kwargs) class KSamplerX0Inpaint(torch.nn.Module): def __init__(self, model): @@ -332,32 +283,40 @@ class KSamplerX0Inpaint(torch.nn.Module): return out def simple_scheduler(model, steps): + s = model.model_sampling sigs = [] - ss = len(model.sigmas) / steps + ss = len(s.sigmas) / steps for x in range(steps): - sigs += [float(model.sigmas[-(1 + int(x * ss))])] + sigs += [float(s.sigmas[-(1 + int(x * ss))])] sigs += [0.0] return torch.FloatTensor(sigs) def ddim_scheduler(model, steps): + s = model.model_sampling sigs = [] - ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False) - for x in range(len(ddim_timesteps) - 1, -1, -1): - ts = ddim_timesteps[x] - if ts > 999: - ts = 999 - sigs.append(model.t_to_sigma(torch.tensor(ts))) + ss = len(s.sigmas) // steps + x = 1 + while x < len(s.sigmas): + sigs += [float(s.sigmas[x])] + x += ss + sigs = sigs[::-1] sigs += [0.0] return torch.FloatTensor(sigs) -def sgm_scheduler(model, steps): +def normal_scheduler(model, steps, sgm=False, floor=False): + s = model.model_sampling + start = s.timestep(s.sigma_max) + end = s.timestep(s.sigma_min) + + if sgm: + timesteps = torch.linspace(start, end, steps + 1)[:-1] + else: + timesteps = torch.linspace(start, end, steps) + sigs = [] - timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int) for x in range(len(timesteps)): ts = timesteps[x] - if ts > 999: - ts = 999 - sigs.append(model.t_to_sigma(torch.tensor(ts))) + sigs.append(s.sigma(ts)) sigs += [0.0] return torch.FloatTensor(sigs) 
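# Illustrative sketch only, under the assumption that the model_sampling module introduced in
# this patch (comfy/model_sampling.py) is importable: the reworked schedulers above read sigmas
# from the model's model_sampling object instead of wrapping the model in a k-diffusion
# CompVisDenoiser, and rely on conversions like these.
import torch
from comfy.model_sampling import ModelSamplingDiscrete

ms = ModelSamplingDiscrete()                # default linear beta schedule (SD1.x/SD2.x)
print(float(ms.sigma_min), float(ms.sigma_max))
print(ms.percent_to_sigma(0.5))             # sigma reached halfway through sampling
print(float(ms.sigma(torch.tensor(999))))   # last discrete timestep maps back to sigma_max

# normal_scheduler(model, steps) above interpolates timesteps between
# model.model_sampling.timestep(sigma_max) and .timestep(sigma_min),
# converts each back through .sigma(), and appends a trailing 0.0.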
@@ -389,19 +348,19 @@ def resolve_areas_and_cond_masks(conditions, h, w, device): # While we're doing this, we can also resolve the mask device and scaling for performance reasons for i in range(len(conditions)): c = conditions[i] - if 'area' in c[1]: - area = c[1]['area'] + if 'area' in c: + area = c['area'] if area[0] == "percentage": - modified = c[1].copy() + modified = c.copy() area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w)) modified['area'] = area - c = [c[0], modified] + c = modified conditions[i] = c - if 'mask' in c[1]: - mask = c[1]['mask'] + if 'mask' in c: + mask = c['mask'] mask = mask.to(device=device) - modified = c[1].copy() + modified = c.copy() if len(mask.shape) == 2: mask = mask.unsqueeze(0) if mask.shape[1] != h or mask.shape[2] != w: @@ -422,66 +381,70 @@ def resolve_areas_and_cond_masks(conditions, h, w, device): modified['area'] = area modified['mask'] = mask - conditions[i] = [c[0], modified] + conditions[i] = modified def create_cond_with_same_area_if_none(conds, c): - if 'area' not in c[1]: + if 'area' not in c: return - c_area = c[1]['area'] + c_area = c['area'] smallest = None for x in conds: - if 'area' in x[1]: - a = x[1]['area'] + if 'area' in x: + a = x['area'] if c_area[2] >= a[2] and c_area[3] >= a[3]: if a[0] + a[2] >= c_area[0] + c_area[2]: if a[1] + a[3] >= c_area[1] + c_area[3]: if smallest is None: smallest = x - elif 'area' not in smallest[1]: + elif 'area' not in smallest: smallest = x else: - if smallest[1]['area'][0] * smallest[1]['area'][1] > a[0] * a[1]: + if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]: smallest = x else: if smallest is None: smallest = x if smallest is None: return - if 'area' in smallest[1]: - if smallest[1]['area'] == c_area: + if 'area' in smallest: + if smallest['area'] == c_area: return - n = c[1].copy() - conds += [[smallest[0], n]] + + out = c.copy() + out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied? 
+ conds += [out] def calculate_start_end_timesteps(model, conds): + s = model.model_sampling for t in range(len(conds)): x = conds[t] timestep_start = None timestep_end = None - if 'start_percent' in x[1]: - timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0))) - if 'end_percent' in x[1]: - timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0))) + if 'start_percent' in x: + timestep_start = s.percent_to_sigma(x['start_percent']) + if 'end_percent' in x: + timestep_end = s.percent_to_sigma(x['end_percent']) if (timestep_start is not None) or (timestep_end is not None): - n = x[1].copy() + n = x.copy() if (timestep_start is not None): n['timestep_start'] = timestep_start if (timestep_end is not None): n['timestep_end'] = timestep_end - conds[t] = [x[0], n] + conds[t] = n def pre_run_control(model, conds): + s = model.model_sampling for t in range(len(conds)): x = conds[t] timestep_start = None timestep_end = None - percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0)) - if 'control' in x[1]: - x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function) + percent_to_timestep_function = lambda a: s.percent_to_sigma(a) + if 'control' in x: + x['control'].pre_run(model, percent_to_timestep_function) def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] @@ -490,16 +453,16 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): uncond_other = [] for t in range(len(conds)): x = conds[t] - if 'area' not in x[1]: - if name in x[1] and x[1][name] is not None: - cond_cnets.append(x[1][name]) + if 'area' not in x: + if name in x and x[name] is not None: + cond_cnets.append(x[name]) else: cond_other.append((x, t)) for t in range(len(uncond)): x = uncond[t] - if 'area' not in x[1]: - if name in x[1] and x[1][name] is not None: - uncond_cnets.append(x[1][name]) + if 'area' not in x: + if name in x and x[name] is not None: + uncond_cnets.append(x[name]) else: uncond_other.append((x, t)) @@ -509,47 +472,35 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): for x in range(len(cond_cnets)): temp = uncond_other[x % len(uncond_other)] o = temp[0] - if name in o[1] and o[1][name] is not None: - n = o[1].copy() + if name in o and o[name] is not None: + n = o.copy() n[name] = uncond_fill_func(cond_cnets, x) - uncond += [[o[0], n]] + uncond += [n] else: - n = o[1].copy() + n = o.copy() n[name] = uncond_fill_func(cond_cnets, x) - uncond[temp[1]] = [o[0], n] + uncond[temp[1]] = n -def encode_adm(model, conds, batch_size, width, height, device, prompt_type): +def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs): for t in range(len(conds)): x = conds[t] - adm_out = None - if 'adm' in x[1]: - adm_out = x[1]["adm"] - else: - params = x[1].copy() - params["width"] = params.get("width", width * 8) - params["height"] = params.get("height", height * 8) - params["prompt_type"] = params.get("prompt_type", prompt_type) - adm_out = model.encode_adm(device=device, **params) - - if adm_out is not None: - x[1] = x[1].copy() - x[1]["adm_encoded"] = comfy.utils.repeat_to_batch_size(adm_out, batch_size).to(device) - - return conds - -def encode_cond(model_function, key, conds, device, **kwargs): - for t in range(len(conds)): - x = conds[t] - params = x[1].copy() + params = x.copy() params["device"] = device + params["noise"] = noise + params["width"] = params.get("width", 
noise.shape[3] * 8) + params["height"] = params.get("height", noise.shape[2] * 8) + params["prompt_type"] = params.get("prompt_type", prompt_type) for k in kwargs: if k not in params: params[k] = kwargs[k] out = model_function(**params) - if out is not None: - x[1] = x[1].copy() - x[1][key] = out + x = x.copy() + model_conds = x['model_conds'].copy() + for k in out: + model_conds[k] = out[k] + x['model_conds'] = model_conds + conds[t] = x return conds class Sampler: @@ -557,95 +508,79 @@ class Sampler: pass def max_denoise(self, model_wrap, sigmas): - return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05) - -class DDIM(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - timesteps = [] - for s in range(sigmas.shape[0]): - timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s])) - noise_mask = None - if denoise_mask is not None: - noise_mask = 1.0 - denoise_mask - - ddim_callback = None - if callback is not None: - total_steps = len(timesteps) - 1 - ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps) - - max_denoise = self.max_denoise(model_wrap, sigmas) - - ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device) - ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) - z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise) - samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps, - batch_size=noise.shape[0], - shape=noise.shape[1:], - verbose=False, - eta=0.0, - x_T=z_enc, - x0=latent_image, - img_callback=ddim_callback, - denoise_function=model_wrap.predict_eps_discrete_timestep, - extra_args=extra_args, - mask=noise_mask, - to_zero=sigmas[-1]==0, - end_step=sigmas.shape[0] - 1, - disable_pbar=disable_pbar) - return samples + max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma class UNIPC(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar) + return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar) class UNIPCBH2(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) + return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) -KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", +KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", 
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"] + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] -def ksampler(sampler_name, extra_options={}): - class KSAMPLER(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - extra_args["denoise_mask"] = denoise_mask - model_k = KSamplerX0Inpaint(model_wrap) - model_k.latent_image = latent_image +class KSAMPLER(Sampler): + def __init__(self, sampler_function, extra_options={}, inpaint_options={}): + self.sampler_function = sampler_function + self.extra_options = extra_options + self.inpaint_options = inpaint_options + + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + extra_args["denoise_mask"] = denoise_mask + model_k = KSamplerX0Inpaint(model_wrap) + model_k.latent_image = latent_image + if self.inpaint_options.get("random", False): #TODO: Should this be the default? + generator = torch.manual_seed(extra_args.get("seed", 41) + 1) + model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device) + else: model_k.noise = noise - if self.max_denoise(model_wrap, sigmas): - noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) - else: - noise = noise * sigmas[0] + if self.max_denoise(model_wrap, sigmas): + noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) + else: + noise = noise * sigmas[0] - k_callback = None - total_steps = len(sigmas) - 1 - if callback is not None: - k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) + k_callback = None + total_steps = len(sigmas) - 1 + if callback is not None: + k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) + if latent_image is not None: + noise += latent_image + + samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) + return samples + + +def ksampler(sampler_name, extra_options={}, inpaint_options={}): + if sampler_name == "dpm_fast": + def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable): sigma_min = sigmas[-1] if sigma_min == 0: sigma_min = sigmas[-2] + total_steps = len(sigmas) - 1 + return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable) + sampler_function = dpm_fast_function + elif sampler_name == "dpm_adaptive": + def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable): + sigma_min = sigmas[-1] + if sigma_min == 0: + sigma_min = sigmas[-2] + return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable) + sampler_function = dpm_adaptive_function + else: + sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) - if latent_image is not None: - noise += latent_image - if sampler_name == "dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) - elif sampler_name == "dpm_adaptive": - samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, 
disable=disable_pbar) - else: - samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options) - return samples - return KSAMPLER + return KSAMPLER(sampler_function, extra_options, inpaint_options) def wrap_model(model): model_denoise = CFGNoisePredictor(model) - if model.model_type == model_base.ModelType.V_PREDICTION: - model_wrap = CompVisVDenoiser(model_denoise, quantize=True) - else: - model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True) - return model_wrap + return model_denoise def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): positive = positive[:] @@ -656,8 +591,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model model_wrap = wrap_model(model) - calculate_start_end_timesteps(model_wrap, negative) - calculate_start_end_timesteps(model_wrap, positive) + calculate_start_end_timesteps(model, negative) + calculate_start_end_timesteps(model, positive) #make sure each cond area has an opposite one with the same area for c in positive: @@ -665,21 +600,17 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model for c in negative: create_cond_with_same_area_if_none(positive, c) - pre_run_control(model_wrap, negative + positive) + pre_run_control(model, negative + positive) - apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) if latent_image is not None: latent_image = model.process_latent_in(latent_image) - if model.is_adm(): - positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive") - negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative") - - if hasattr(model, 'cond_concat'): - positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) - negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) + if hasattr(model, 'extra_conds'): + positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) + negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} @@ -690,30 +621,29 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", " SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] def calculate_sigmas_scheduler(model, scheduler_name, steps): - model_wrap = wrap_model(model) if scheduler_name == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, 
sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) elif scheduler_name == "exponential": - sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) elif scheduler_name == "normal": - sigmas = model_wrap.get_sigmas(steps) + sigmas = normal_scheduler(model, steps) elif scheduler_name == "simple": - sigmas = simple_scheduler(model_wrap, steps) + sigmas = simple_scheduler(model, steps) elif scheduler_name == "ddim_uniform": - sigmas = ddim_scheduler(model_wrap, steps) + sigmas = ddim_scheduler(model, steps) elif scheduler_name == "sgm_uniform": - sigmas = sgm_scheduler(model_wrap, steps) + sigmas = normal_scheduler(model, steps, sgm=True) else: print("error invalid scheduler", self.scheduler) return sigmas -def sampler_class(name): +def sampler_object(name): if name == "uni_pc": - sampler = UNIPC + sampler = UNIPC() elif name == "uni_pc_bh2": - sampler = UNIPCBH2 + sampler = UNIPCBH2() elif name == "ddim": - sampler = DDIM + sampler = ksampler("euler", inpaint_options={"random": True}) else: sampler = ksampler(name) return sampler @@ -776,6 +706,6 @@ class KSampler: else: return torch.zeros_like(noise) - sampler = sampler_class(self.sampler) + sampler = sampler_object(self.sampler) - return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) + return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) diff --git a/comfy/sd.py b/comfy/sd.py index c364b723c..f4f84d0a0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -23,6 +23,7 @@ import comfy.model_patcher import comfy.lora import comfy.t2i_adapter.adapter import comfy.supported_models_base +import comfy.taesd.taesd def load_model_weights(model, sd): m, u = model.load_state_dict(sd, strict=False) @@ -55,13 +56,26 @@ def load_clip_weights(model, sd): def load_lora_for_models(model, clip, lora, strength_model, strength_clip): - key_map = comfy.lora.model_lora_keys_unet(model.model) - key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + key_map = {} + if model is not None: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if clip is not None: + key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + loaded = comfy.lora.load_lora(lora, key_map) - new_modelpatcher = model.clone() - k = new_modelpatcher.add_patches(loaded, strength_model) - new_clip = clip.clone() - k1 = new_clip.add_patches(loaded, strength_clip) + if model is not None: + new_modelpatcher = model.clone() + k = new_modelpatcher.add_patches(loaded, strength_model) + else: + k = () + new_modelpatcher = None + + if clip is not None: + new_clip = clip.clone() + k1 = new_clip.add_patches(loaded, strength_clip) + else: + k1 = () + new_clip = None k = set(k) k1 = set(k1) for x in loaded: @@ -82,10 +96,7 @@ class CLIP: load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() params['device'] = offload_device - if model_management.should_use_fp16(load_device, 
prioritize_performance=False): - params['dtype'] = torch.float16 - else: - params['dtype'] = torch.float32 + params['dtype'] = model_management.text_encoder_dtype(load_device) self.cond_stage_model = clip(**(params)) @@ -144,10 +155,24 @@ class VAE: if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) + self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) + self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) + if config is None: - #default SD1.x/SD2.x VAE parameters - ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) + if "decoder.mid.block_1.mix_factor" in sd: + encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + decoder_config = encoder_config.copy() + decoder_config["video_kernel_size"] = [3, 1, 1] + decoder_config["alpha"] = 0.0 + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config}, + decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) + elif "taesd_decoder.1.weight" in sd: + self.first_stage_model = comfy.taesd.taesd.TAESD() + else: + #default SD1.x/SD2.x VAE parameters + ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) else: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() @@ -162,10 +187,12 @@ class VAE: if device is None: device = model_management.vae_device() self.device = device - self.offload_device = model_management.vae_offload_device() + offload_device = model_management.vae_offload_device() self.vae_dtype = model_management.vae_dtype() self.first_stage_model.to(self.vae_dtype) + self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) @@ -194,10 +221,9 @@ class VAE: return samples def decode(self, samples_in): - self.first_stage_model = self.first_stage_model.to(self.device) try: - memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7 - model_management.free_memory(memory_used, self.device) + memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) + model_management.load_models_gpu([self.patcher], memory_required=memory_used) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory 
/ memory_used) batch_number = max(1, batch_number) @@ -210,22 +236,19 @@ class VAE: print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) - self.first_stage_model = self.first_stage_model.to(self.offload_device) pixel_samples = pixel_samples.cpu().movedim(1,-1) return pixel_samples def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): - self.first_stage_model = self.first_stage_model.to(self.device) + model_management.load_model_gpu(self.patcher) output = self.decode_tiled_(samples, tile_x, tile_y, overlap) - self.first_stage_model = self.first_stage_model.to(self.offload_device) return output.movedim(1,-1) def encode(self, pixel_samples): - self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) try: - memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. - model_management.free_memory(memory_used, self.device) + memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) + model_management.load_models_gpu([self.patcher], memory_required=memory_used) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -238,14 +261,12 @@ class VAE: print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") samples = self.encode_tiled_(pixel_samples) - self.first_stage_model = self.first_stage_model.to(self.offload_device) return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): - self.first_stage_model = self.first_stage_model.to(self.device) + model_management.load_model_gpu(self.patcher) pixel_samples = pixel_samples.movedim(-1,1) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) - self.first_stage_model = self.first_stage_model.to(self.offload_device) return samples def get_sd(self): @@ -360,7 +381,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl from . 
import latent_formats model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor) - model_config.unet_config = unet_config + model_config.unet_config = model_detection.convert_config(unet_config) if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type) @@ -388,11 +409,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"): clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model.clip_h elif clip_config["target"].endswith("FrozenCLIPEmbedder"): clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model.clip_l load_clip_weights(w, state_dict) return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) @@ -429,6 +452,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae_sd = model_config.process_vae_state_dict(vae_sd) vae = VAE(sd=vae_sd) if output_clip: @@ -453,20 +477,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o return (model_patcher, clip, vae, clipvision) -def load_unet(unet_path): #load unet in diffusers format - sd = comfy.utils.load_torch_file(unet_path) +def load_unet_state_dict(sd): #load unet in diffusers format parameters = comfy.utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) if "input_blocks.0.0.weight" in sd: #ldm model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) if model_config is None: - raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) + return None new_sd = sd else: #diffusers model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype) if model_config is None: - print("ERROR UNSUPPORTED UNET", unet_path) return None diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) @@ -481,8 +503,19 @@ def load_unet(unet_path): #load unet in diffusers format model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") + left_over = sd.keys() + if len(left_over) > 0: + print("left over keys in unet:", left_over) return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) +def load_unet(unet_path): + sd = comfy.utils.load_torch_file(unet_path) + model = load_unet_state_dict(sd) + if model is None: + print("ERROR UNSUPPORTED UNET", unet_path) + raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) + return model + def save_checkpoint(output_path, model, clip, vae, metadata=None): model_management.load_models_gpu([model, clip.load_model()]) sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 9978b6c35..58acb97fc 100644 --- a/comfy/sd1_clip.py +++ 
b/comfy/sd1_clip.py @@ -8,34 +8,56 @@ import zipfile from . import model_management import contextlib +def gen_empty_tokens(special_tokens, length): + start_token = special_tokens.get("start", None) + end_token = special_tokens.get("end", None) + pad_token = special_tokens.get("pad") + output = [] + if start_token is not None: + output.append(start_token) + if end_token is not None: + output.append(end_token) + output += [pad_token] * (length - len(output)) + return output + class ClipTokenWeightEncoder: def encode_token_weights(self, token_weight_pairs): - to_encode = list(self.empty_tokens) + to_encode = list() + max_token_len = 0 + has_weights = False for x in token_weight_pairs: tokens = list(map(lambda a: a[0], x)) + max_token_len = max(len(tokens), max_token_len) + has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x)) to_encode.append(tokens) + sections = len(to_encode) + if has_weights or sections == 0: + to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) + out, pooled = self.encode(to_encode) - z_empty = out[0:1] - if pooled.shape[0] > 1: - first_pooled = pooled[1:2] + if pooled is not None: + first_pooled = pooled[0:1].cpu() else: - first_pooled = pooled[0:1] + first_pooled = pooled output = [] - for k in range(1, out.shape[0]): + for k in range(0, sections): z = out[k:k+1] - for i in range(len(z)): - for j in range(len(z[i])): - weight = token_weight_pairs[k - 1][j][1] - z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j] + if has_weights: + z_empty = out[-1] + for i in range(len(z)): + for j in range(len(z[i])): + weight = token_weight_pairs[k][j][1] + if weight != 1.0: + z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j] output.append(z) if (len(output) == 0): - return z_empty.cpu(), first_pooled.cpu() - return torch.cat(output, dim=-2).cpu(), first_pooled.cpu() + return out[-1:].cpu(), first_pooled + return torch.cat(output, dim=-2).cpu(), first_pooled -class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): +class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" LAYERS = [ "last", @@ -43,37 +65,43 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): "hidden" ] def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, - freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32 + freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None, + special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig, + model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS self.num_layers = 12 if textmodel_path is not None: - self.transformer = CLIPTextModel.from_pretrained(textmodel_path) + self.transformer = model_class.from_pretrained(textmodel_path) else: if textmodel_json_config is None: textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") - config = CLIPTextConfig.from_json_file(textmodel_json_config) + config = config_class.from_json_file(textmodel_json_config) self.num_layers = config.num_hidden_layers with comfy.ops.use_comfy_ops(device, dtype): with modeling_utils.no_init_weights(): - self.transformer = CLIPTextModel(config) + self.transformer = model_class(config) + self.inner_name = inner_name if dtype is not None: 
self.transformer.to(dtype) - self.transformer.text_model.embeddings.token_embedding.to(torch.float32) - self.transformer.text_model.embeddings.position_embedding.to(torch.float32) + inner_model = getattr(self.transformer, self.inner_name) + if hasattr(inner_model, "embeddings"): + inner_model.embeddings.to(torch.float32) + else: + self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32)) self.max_length = max_length if freeze: self.freeze() self.layer = layer self.layer_idx = None - self.empty_tokens = [[49406] + [49407] * 76] + self.special_tokens = special_tokens self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) self.enable_attention_masks = False - self.layer_norm_hidden_state = True + self.layer_norm_hidden_state = layer_norm_hidden_state if layer == "hidden": assert layer_idx is not None assert abs(layer_idx) <= self.num_layers @@ -117,7 +145,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): else: print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1]) while len(tokens_temp) < len(x): - tokens_temp += [self.empty_tokens[0][-1]] + tokens_temp += [self.special_tokens["pad"]] out_tokens += [tokens_temp] n = token_dict_size @@ -142,12 +170,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = torch.LongTensor(tokens).to(device) - if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32: + if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32: precision_scope = torch.autocast else: - precision_scope = lambda a, b: contextlib.nullcontext(a) + precision_scope = lambda a, dtype: contextlib.nullcontext(a) - with precision_scope(model_management.get_autocast_device(device), torch.float32): + with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32): attention_mask = None if self.enable_attention_masks: attention_mask = torch.zeros_like(tokens) @@ -168,12 +196,16 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): else: z = outputs.hidden_states[self.layer_idx] if self.layer_norm_hidden_state: - z = self.transformer.text_model.final_layer_norm(z) + z = getattr(self.transformer, self.inner_name).final_layer_norm(z) - pooled_output = outputs.pooler_output - if self.text_projection is not None: + if hasattr(outputs, "pooler_output"): + pooled_output = outputs.pooler_output.float() + else: + pooled_output = None + + if self.text_projection is not None and pooled_output is not None: pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() - return z.float(), pooled_output.float() + return z.float(), pooled_output def encode(self, tokens): return self(tokens) @@ -278,7 +310,13 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No valid_file = None for embed_dir in embedding_directory: - embed_path = os.path.join(embed_dir, embedding_name) + embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name)) + embed_dir = os.path.abspath(embed_dir) + try: + if os.path.commonpath((embed_dir, embed_path)) != embed_dir: + continue + except: + continue if not os.path.isfile(embed_path): extensions = ['.safetensors', '.pt', '.bin'] for x in extensions: @@ -336,18 +374,25 @@ def 
load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No embed_out = next(iter(values)) return embed_out -class SD1Tokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'): +class SDTokenizer: + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") - self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) + self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path) self.max_length = max_length - self.max_tokens_per_section = self.max_length - 2 empty = self.tokenizer('')["input_ids"] - self.start_token = empty[0] - self.end_token = empty[1] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] self.pad_with_end = pad_with_end + self.pad_to_max_length = pad_to_max_length + vocab = self.tokenizer.get_vocab() self.inv_vocab = {v: k for k, v in vocab.items()} self.embedding_directory = embedding_directory @@ -408,11 +453,13 @@ class SD1Tokenizer: else: continue #parse word - tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]]) + tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]]) #reshape token array to CLIP input size batched_tokens = [] - batch = [(self.start_token, 1.0, 0)] + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0, 0)) batched_tokens.append(batch) for i, t_group in enumerate(tokens): #determine if we're going to try and keep the tokens in a single batch @@ -429,16 +476,21 @@ class SD1Tokenizer: #add end token and pad else: batch.append((self.end_token, 1.0, 0)) - batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) #start new batch - batch = [(self.start_token, 1.0, 0)] + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0, 0)) batched_tokens.append(batch) else: batch.extend([(t,w,i+1) for t,w in t_group]) t_group = [] #fill last batch - batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1)) + batch.append((self.end_token, 1.0, 0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch))) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] @@ -448,3 +500,40 @@ class SD1Tokenizer: def untokenize(self, token_weight_pair): return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) + + +class SD1Tokenizer: + def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer): + self.clip_name = clip_name + self.clip = "clip_{}".format(self.clip_name) + setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory)) + + def tokenize_with_weights(self, text:str, return_word_ids=False): + out = {} + out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids) + return out + + def untokenize(self, token_weight_pair): + return getattr(self, self.clip).untokenize(token_weight_pair) + + +class SD1ClipModel(torch.nn.Module): + 
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs): + super().__init__() + self.clip_name = clip_name + self.clip = "clip_{}".format(self.clip_name) + setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) + + def clip_layer(self, layer_idx): + getattr(self, self.clip).clip_layer(layer_idx) + + def reset_clip_layer(self): + getattr(self, self.clip).reset_clip_layer() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs = token_weight_pairs[self.clip_name] + out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs) + return out, pooled + + def load_sd(self, sd): + return getattr(self, self.clip).load_sd(sd) diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index 05e50a005..2ee0ca055 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -2,16 +2,23 @@ from comfy import sd1_clip import torch import os -class SD2ClipModel(sd1_clip.SD1ClipModel): +class SD2ClipHModel(sd1_clip.SDClipModel): def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): if layer == "penultimate": layer="hidden" layer_idx=23 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) - self.empty_tokens = [[49406] + [49407] + [0] * 75] + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) -class SD2Tokenizer(sd1_clip.SD1Tokenizer): +class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024) + +class SD2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None): + super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer) + +class SD2ClipModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, **kwargs): + super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs) diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index e3ac2ee0b..673399e22 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -2,28 +2,27 @@ from comfy import sd1_clip import torch import os -class SDXLClipG(sd1_clip.SD1ClipModel): +class SDXLClipG(sd1_clip.SDClipModel): def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): if layer == "penultimate": layer="hidden" layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) - self.empty_tokens = [[49406] + [49407] + [0] * 75] - self.layer_norm_hidden_state = False + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 0}, 
layer_norm_hidden_state=False) def load_sd(self, sd): return super().load_sd(sd) -class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer): +class SDXLClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') -class SDXLTokenizer(sd1_clip.SD1Tokenizer): +class SDXLTokenizer: def __init__(self, embedding_directory=None): - self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) def tokenize_with_weights(self, text:str, return_word_ids=False): @@ -38,8 +37,7 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer): class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None): super().__init__() - self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype) - self.clip_l.layer_norm_hidden_state = False + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_g = SDXLClipG(device=device, dtype=dtype) def clip_layer(self, layer_idx): @@ -63,21 +61,6 @@ class SDXLClipModel(torch.nn.Module): else: return self.clip_l.load_sd(sd) -class SDXLRefinerClipModel(torch.nn.Module): +class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None): - super().__init__() - self.clip_g = SDXLClipG(device=device, dtype=dtype) - - def clip_layer(self, layer_idx): - self.clip_g.clip_layer(layer_idx) - - def reset_clip_layer(self): - self.clip_g.reset_clip_layer() - - def encode_token_weights(self, token_weight_pairs): - token_weight_pairs_g = token_weight_pairs["g"] - g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) - return g_out, g_pooled - - def load_sd(self, sd): - return self.clip_g.load_sd(sd) + super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index bb8ae2148..7e2ac677d 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE): "model_channels": 320, "use_linear_in_transformer": False, "adm_in_channels": None, + "use_temporal_attention": False, } unet_extra_config = { @@ -38,8 +39,15 @@ class SD15(supported_models_base.BASE): if ids.dtype == torch.float32: state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() + replace_prefix = {} + replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l." 
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) return state_dict + def process_clip_state_dict_for_saving(self, state_dict): + replace_prefix = {"clip_l.": "cond_stage_model."} + return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def clip_target(self): return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel) @@ -49,6 +57,7 @@ class SD20(supported_models_base.BASE): "model_channels": 320, "use_linear_in_transformer": True, "adm_in_channels": None, + "use_temporal_attention": False, } latent_format = latent_formats.SD15 @@ -62,12 +71,12 @@ class SD20(supported_models_base.BASE): return model_base.ModelType.EPS def process_clip_state_dict(self, state_dict): - state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) + state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24) return state_dict def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} - replace_prefix[""] = "cond_stage_model.model." + replace_prefix["clip_h"] = "cond_stage_model.model" state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) return state_dict @@ -81,6 +90,7 @@ class SD21UnclipL(SD20): "model_channels": 320, "use_linear_in_transformer": True, "adm_in_channels": 1536, + "use_temporal_attention": False, } clip_vision_prefix = "embedder.model.visual." @@ -93,6 +103,7 @@ class SD21UnclipH(SD20): "model_channels": 320, "use_linear_in_transformer": True, "adm_in_channels": 2048, + "use_temporal_attention": False, } clip_vision_prefix = "embedder.model.visual." 
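Aside (not part of the patch): the process_clip_state_dict hunks above reflect the text-encoder refactor — the single CLIP model now lives under a named attribute (clip_l for SD1.x, clip_h for SD2.x), so checkpoint keys are re-prefixed on load and the inverse mapping is applied on save. A minimal standalone sketch of what such a prefix rewrite does (illustrative only; the real helper is comfy.utils.state_dict_prefix_replace, which also supports options like filter_keys):

```python
# Illustrative prefix rewrite on a state dict; NOT the comfy.utils implementation.
def prefix_replace(state_dict, replace_prefix):
    out = {}
    for key, value in state_dict.items():
        for old, new in replace_prefix.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        out[key] = value
    return out

# Load direction for SD1.x: nest the text encoder keys under clip_l.
sd = {"cond_stage_model.transformer.text_model.final_layer_norm.weight": "..."}
nested = prefix_replace(sd, {"cond_stage_model.": "cond_stage_model.clip_l."})
# -> {"cond_stage_model.clip_l.transformer.text_model.final_layer_norm.weight": "..."}
# Saving applies the inverse mapping so exported checkpoints keep the original layout.
```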
@@ -104,7 +115,8 @@ class SDXLRefiner(supported_models_base.BASE): "use_linear_in_transformer": True, "context_dim": 1280, "adm_in_channels": 2560, - "transformer_depth": [0, 4, 4, 0], + "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0], + "use_temporal_attention": False, } latent_format = latent_formats.SDXL @@ -139,9 +151,10 @@ class SDXL(supported_models_base.BASE): unet_config = { "model_channels": 320, "use_linear_in_transformer": True, - "transformer_depth": [0, 2, 10], + "transformer_depth": [0, 0, 2, 2, 10, 10], "context_dim": 2048, - "adm_in_channels": 2816 + "adm_in_channels": 2816, + "use_temporal_attention": False, } latent_format = latent_formats.SDXL @@ -165,6 +178,7 @@ class SDXL(supported_models_base.BASE): replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model" state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" + keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection" keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) @@ -189,5 +203,40 @@ class SDXL(supported_models_base.BASE): def clip_target(self): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel) +class SSD1B(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 0, 2, 2, 4, 4], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + } -models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL] +class SVD_img2vid(supported_models_base.BASE): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 768, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual." 
+ + latent_format = latent_formats.SD15 + + sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002} + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SVD_img2vid(self, device=device) + return out + + def clip_target(self): + return None + +models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B] +models += [SVD_img2vid] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 88a1d7fde..3412cfea0 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -19,7 +19,7 @@ class BASE: clip_prefix = [] clip_vision_prefix = None noise_aug_config = None - beta_schedule = "linear" + sampling_settings = {} latent_format = latent_formats.LatentFormat @classmethod @@ -53,6 +53,12 @@ class BASE: def process_clip_state_dict(self, state_dict): return state_dict + def process_unet_state_dict(self, state_dict): + return state_dict + + def process_vae_state_dict(self, state_dict): + return state_dict + def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {"": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 8df1f1609..46f3097a2 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -46,15 +46,16 @@ class TAESD(nn.Module): latent_magnitude = 3 latent_shift = 0.5 - def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"): + def __init__(self, encoder_path=None, decoder_path=None): """Initialize pretrained TAESD on the given device from the given checkpoints.""" super().__init__() - self.encoder = Encoder() - self.decoder = Decoder() + self.taesd_encoder = Encoder() + self.taesd_decoder = Decoder() + self.vae_scale = torch.nn.Parameter(torch.tensor(1.0)) if encoder_path is not None: - self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True)) + self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True)) if decoder_path is not None: - self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) + self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) @staticmethod def scale_latents(x): @@ -65,3 +66,11 @@ class TAESD(nn.Module): def unscale_latents(x): """[0, 1] -> raw latents""" return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) + + def decode(self, x): + x_sample = self.taesd_decoder(x * self.vae_scale) + x_sample = x_sample.sub(0.5).mul(2) + return x_sample + + def encode(self, x): + return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale diff --git a/comfy/utils.py b/comfy/utils.py index a1807aa1d..294bbb425 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -170,25 +170,12 @@ UNET_MAP_BASIC = { def unet_to_diffusers(unet_config): num_res_blocks = unet_config["num_res_blocks"] - attention_resolutions = unet_config["attention_resolutions"] channel_mult = unet_config["channel_mult"] - transformer_depth = unet_config["transformer_depth"] + transformer_depth = unet_config["transformer_depth"][:] + transformer_depth_output = unet_config["transformer_depth_output"][:] num_blocks = len(channel_mult) - if isinstance(num_res_blocks, int): - num_res_blocks = [num_res_blocks] * num_blocks - if isinstance(transformer_depth, int): - transformer_depth = [transformer_depth] * num_blocks - transformers_per_layer = [] - res = 1 - for i in range(num_blocks): - transformers = 0 - if res in attention_resolutions: - 
transformers = transformer_depth[i] - transformers_per_layer.append(transformers) - res *= 2 - - transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1]) + transformers_mid = unet_config.get("transformer_depth_middle", None) diffusers_unet_map = {} for x in range(num_blocks): @@ -196,10 +183,11 @@ def unet_to_diffusers(unet_config): for i in range(num_res_blocks[x]): for b in UNET_MAP_RESNET: diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b) - if transformers_per_layer[x] > 0: + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b) - for t in range(transformers_per_layer[x]): + for t in range(num_transformers): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) n += 1 @@ -218,7 +206,6 @@ def unet_to_diffusers(unet_config): diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b) num_res_blocks = list(reversed(num_res_blocks)) - transformers_per_layer = list(reversed(transformers_per_layer)) for x in range(num_blocks): n = (num_res_blocks[x] + 1) * x l = num_res_blocks[x] + 1 @@ -227,11 +214,12 @@ def unet_to_diffusers(unet_config): for b in UNET_MAP_RESNET: diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b) c += 1 - if transformers_per_layer[x] > 0: + num_transformers = transformer_depth_output.pop() + if num_transformers > 0: c += 1 for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b) - for t in range(transformers_per_layer[x]): + for t in range(num_transformers): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) if i == l - 1: @@ -270,9 +258,17 @@ def set_attr(obj, attr, value): for name in attrs[:-1]: obj = getattr(obj, name) prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], torch.nn.Parameter(value)) + setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False)) del prev +def copy_to_param(obj, attr, value): + # inplace update tensor instead of replacing it + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + prev = getattr(obj, attrs[-1]) + prev.data.copy_(value) + def get_attr(obj, attr): attrs = attr.split(".") for name in attrs: @@ -311,23 +307,25 @@ def bislerp(samples, width, height): res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] return res - def generate_bilinear_data(length_old, length_new): - coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + def generate_bilinear_data(length_old, length_new, device): + coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") ratios = coords_1 - coords_1.floor() coords_1 = coords_1.to(torch.int64) - coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1 + coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1 coords_2[:,:,:,-1] -= 1 coords_2 = 
torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") coords_2 = coords_2.to(torch.int64) return ratios, coords_1, coords_2 - + + orig_dtype = samples.dtype + samples = samples.float() n,c,h,w = samples.shape h_new, w_new = (height, width) #linear w - ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device) coords_1 = coords_1.expand((n, c, h, -1)) coords_2 = coords_2.expand((n, c, h, -1)) ratios = ratios.expand((n, 1, h, -1)) @@ -340,7 +338,7 @@ def bislerp(samples, width, height): result = result.reshape(n, h, w_new, c).movedim(-1, 1) #linear h - ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device) coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) @@ -351,7 +349,7 @@ def bislerp(samples, width, height): result = slerp(pass_1, pass_2, ratios) result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) - return result + return result.to(orig_dtype) def lanczos(samples, width, height): images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index b52ad8fbd..008d0b8d6 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -16,7 +16,7 @@ class BasicScheduler: } } RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/schedulers" FUNCTION = "get_sigmas" @@ -36,7 +36,7 @@ class KarrasScheduler: } } RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/schedulers" FUNCTION = "get_sigmas" @@ -54,7 +54,7 @@ class ExponentialScheduler: } } RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/schedulers" FUNCTION = "get_sigmas" @@ -73,7 +73,7 @@ class PolyexponentialScheduler: } } RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/schedulers" FUNCTION = "get_sigmas" @@ -81,6 +81,25 @@ class PolyexponentialScheduler: sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) return (sigmas, ) +class SDTurboScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "steps": ("INT", {"default": 1, "min": 1, "max": 10}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model, steps): + timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps] + sigmas = model.model.model_sampling.sigma(timesteps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + return (sigmas, ) + class VPScheduler: @classmethod def INPUT_TYPES(s): @@ -92,7 +111,7 @@ class VPScheduler: } } RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/schedulers" FUNCTION = "get_sigmas" @@ -109,7 +128,7 @@ class SplitSigmas: } } RETURN_TYPES = ("SIGMAS","SIGMAS") - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/sigmas" FUNCTION = "get_sigmas" @@ -118,6 +137,24 @@ class SplitSigmas: sigmas2 = sigmas[step:] return (sigmas1, 
sigmas2) +class FlipSigmas: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"sigmas": ("SIGMAS", ), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/sigmas" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, sigmas): + sigmas = sigmas.flip(0) + if sigmas[0] == 0: + sigmas[0] = 0.0001 + return (sigmas,) + class KSamplerSelect: @classmethod def INPUT_TYPES(s): @@ -126,12 +163,12 @@ class KSamplerSelect: } } RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/samplers" FUNCTION = "get_sampler" def get_sampler(self, sampler_name): - sampler = comfy.samplers.sampler_class(sampler_name)() + sampler = comfy.samplers.sampler_object(sampler_name) return (sampler, ) class SamplerDPMPP_2M_SDE: @@ -145,7 +182,7 @@ class SamplerDPMPP_2M_SDE: } } RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/samplers" FUNCTION = "get_sampler" @@ -154,7 +191,7 @@ class SamplerDPMPP_2M_SDE: sampler_name = "dpmpp_2m_sde" else: sampler_name = "dpmpp_2m_sde_gpu" - sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})() + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) return (sampler, ) @@ -169,7 +206,7 @@ class SamplerDPMPP_SDE: } } RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/samplers" FUNCTION = "get_sampler" @@ -178,7 +215,7 @@ class SamplerDPMPP_SDE: sampler_name = "dpmpp_sde" else: sampler_name = "dpmpp_sde_gpu" - sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})() + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) return (sampler, ) class SamplerCustom: @@ -188,7 +225,7 @@ class SamplerCustom: {"model": ("MODEL",), "add_noise": ("BOOLEAN", {"default": True}), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), "positive": ("CONDITIONING", ), "negative": ("CONDITIONING", ), "sampler": ("SAMPLER", ), @@ -234,13 +271,15 @@ class SamplerCustom: NODE_CLASS_MAPPINGS = { "SamplerCustom": SamplerCustom, + "BasicScheduler": BasicScheduler, "KarrasScheduler": KarrasScheduler, "ExponentialScheduler": ExponentialScheduler, "PolyexponentialScheduler": PolyexponentialScheduler, "VPScheduler": VPScheduler, + "SDTurboScheduler": SDTurboScheduler, "KSamplerSelect": KSamplerSelect, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, - "BasicScheduler": BasicScheduler, "SplitSigmas": SplitSigmas, + "FlipSigmas": FlipSigmas, } diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py new file mode 100644 index 000000000..5ad2235a5 --- /dev/null +++ b/comfy_extras/nodes_images.py @@ -0,0 +1,175 @@ +import nodes +import folder_paths +from comfy.cli_args import args + +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +import numpy as np +import json +import os + +MAX_RESOLUTION = nodes.MAX_RESOLUTION + +class ImageCrop: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + 
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "crop" + + CATEGORY = "image/transform" + + def crop(self, image, width, height, x, y): + x = min(x, image.shape[2] - 1) + y = min(y, image.shape[1] - 1) + to_x = width + x + to_y = height + y + img = image[:,y:to_y, x:to_x, :] + return (img,) + +class RepeatImageBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "repeat" + + CATEGORY = "image/batch" + + def repeat(self, image, amount): + s = image.repeat((amount, 1,1,1)) + return (s,) + +class SaveAnimatedWEBP: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + + methods = {"default": 4, "fastest": 0, "slowest": 6} + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"}), + "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), + "lossless": ("BOOLEAN", {"default": True}), + "quality": ("INT", {"default": 80, "min": 0, "max": 100}), + "method": (list(s.methods.keys()),), + # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): + method = self.methods.get(method) + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + results = list() + pil_images = [] + for image in images: + i = 255. 
* image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pil_images.append(img) + + metadata = pil_images[0].getexif() + if not args.disable_metadata: + if prompt is not None: + metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) + if extra_pnginfo is not None: + inital_exif = 0x010f + for x in extra_pnginfo: + metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x])) + inital_exif -= 1 + + if num_frames == 0: + num_frames = len(pil_images) + + c = len(pil_images) + for i in range(0, c, num_frames): + file = f"{filename}_{counter:05}_.webp" + pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method) + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + counter += 1 + + animated = num_frames != 1 + return { "ui": { "images": results, "animated": (animated,) } } + +class SaveAnimatedPNG: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"}), + "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), + "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + results = list() + pil_images = [] + for image in images: + i = 255. 
* image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pil_images.append(img) + + metadata = None + if not args.disable_metadata: + metadata = PngInfo() + if prompt is not None: + metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) + + file = f"{filename}_{counter:05}_.png" + pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + + return { "ui": { "images": results, "animated": (True,)} } + +NODE_CLASS_MAPPINGS = { + "ImageCrop": ImageCrop, + "RepeatImageBatch": RepeatImageBatch, + "SaveAnimatedWEBP": SaveAnimatedWEBP, + "SaveAnimatedPNG": SaveAnimatedPNG, +} diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 001de39fc..cedf39d63 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -1,4 +1,5 @@ import comfy.utils +import torch def reshape_latent_to(target_shape, latent): if latent.shape[1:] != target_shape[1:]: @@ -67,8 +68,43 @@ class LatentMultiply: samples_out["samples"] = s1 * multiplier return (samples_out,) +class LatentInterpolate: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), + "samples2": ("LATENT",), + "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples1, samples2, ratio): + samples_out = samples1.copy() + + s1 = samples1["samples"] + s2 = samples2["samples"] + + s2 = reshape_latent_to(s1.shape, s2) + + m1 = torch.linalg.vector_norm(s1, dim=(1)) + m2 = torch.linalg.vector_norm(s2, dim=(1)) + + s1 = torch.nan_to_num(s1 / m1) + s2 = torch.nan_to_num(s2 / m2) + + t = (s1 * ratio + s2 * (1.0 - ratio)) + mt = torch.linalg.vector_norm(t, dim=(1)) + st = torch.nan_to_num(t / mt) + + samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, + "LatentInterpolate": LatentInterpolate, } diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py new file mode 100644 index 000000000..efcdf1932 --- /dev/null +++ b/comfy_extras/nodes_model_advanced.py @@ -0,0 +1,205 @@ +import folder_paths +import comfy.sd +import comfy.model_sampling +import torch + +class LCM(comfy.model_sampling.EPS): + def calculate_denoised(self, sigma, model_output, model_input): + timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + x0 = model_input - model_output * sigma + + sigma_data = 0.5 + scaled_timestep = timestep * 10.0 #timestep_scaling + + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 + + return c_out * x0 + c_skip * model_input + +class ModelSamplingDiscreteDistilled(torch.nn.Module): + original_timesteps = 50 + + def __init__(self): + super().__init__() + self.sigma_data = 1.0 + timesteps = 1000 + 
beta_start = 0.00085 + beta_end = 0.012 + + betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + + self.skip_steps = timesteps // self.original_timesteps + + + alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) + for x in range(self.original_timesteps): + alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] + + sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 + self.set_sigmas(sigmas) + + def set_sigmas(self, sigmas): + self.register_buffer('sigmas', sigmas) + self.register_buffer('log_sigmas', sigmas.log()) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + log_sigma = sigma.log() + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] + return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device) + + def sigma(self, timestep): + t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) + low_idx = t.floor().long() + high_idx = t.ceil().long() + w = t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp().to(timestep.device) + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 999999999.9 + if percent >= 1.0: + return 0.0 + percent = 1.0 - percent + return self.sigma(torch.tensor(percent * 999.0)).item() + + +def rescale_zero_terminal_snr_sigmas(sigmas): + alphas_cumprod = 1 / ((sigmas * sigmas) + 1) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. 
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return ((1 - alphas_bar) / alphas_bar) ** 0.5 + +class ModelSamplingDiscrete: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "sampling": (["eps", "v_prediction", "lcm"],), + "zsnr": ("BOOLEAN", {"default": False}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, sampling, zsnr): + m = model.clone() + + sampling_base = comfy.model_sampling.ModelSamplingDiscrete + if sampling == "eps": + sampling_type = comfy.model_sampling.EPS + elif sampling == "v_prediction": + sampling_type = comfy.model_sampling.V_PREDICTION + elif sampling == "lcm": + sampling_type = LCM + sampling_base = ModelSamplingDiscreteDistilled + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced() + if zsnr: + model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) + + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + +class ModelSamplingContinuousEDM: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "sampling": (["v_prediction", "eps"],), + "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, sampling, sigma_max, sigma_min): + m = model.clone() + + if sampling == "eps": + sampling_type = comfy.model_sampling.EPS + elif sampling == "v_prediction": + sampling_type = comfy.model_sampling.V_PREDICTION + + class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced() + model_sampling.set_sigma_range(sigma_min, sigma_max) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + +class RescaleCFG: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, multiplier): + def rescale_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + sigma = args["sigma"] + sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1)) + x_orig = args["input"] + + #rescale cfg has to be done on v-pred model output + x = x_orig / (sigma * sigma + 1.0) + cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma) + uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma) + + #rescalecfg + x_cfg = uncond + cond_scale * (cond - uncond) + ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True) + ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True) + + x_rescaled = x_cfg * (ro_pos / ro_cfg) + x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg + + return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5) + + m = model.clone() + m.set_model_sampler_cfg_function(rescale_cfg) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "ModelSamplingDiscrete": ModelSamplingDiscrete, + "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, + "RescaleCFG": RescaleCFG, +} 
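Side note (not part of the patch): the `rescale_zero_terminal_snr_sigmas` helper added above keeps the first sigma of a discrete schedule unchanged while pushing the terminal step to (near) zero SNR, i.e. a very large final sigma. A minimal standalone sketch, assuming only PyTorch and a toy linearly spaced schedule, that reproduces the helper to show this effect:

```python
# Illustrative sketch only, not part of the patch. Reproduces the
# rescale_zero_terminal_snr_sigmas helper from nodes_model_advanced.py
# and applies it to a toy sigma schedule.
import torch

def rescale_zero_terminal_snr_sigmas(sigmas):
    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero, then scale so the first keeps its value.
    alphas_bar_sqrt -= alphas_bar_sqrt_T
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    alphas_bar = alphas_bar_sqrt ** 2
    alphas_bar[-1] = 4.8973451890853435e-08  # avoid division by zero at the terminal step
    return ((1 - alphas_bar) / alphas_bar) ** 0.5

sigmas = torch.linspace(0.03, 14.6, 1000)   # toy discrete schedule (hypothetical values)
rescaled = rescale_zero_terminal_snr_sigmas(sigmas)
print(sigmas[0].item(), rescaled[0].item())    # first sigma is preserved (up to float error)
print(sigmas[-1].item(), rescaled[-1].item())  # last sigma becomes very large (near-zero terminal SNR)
```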
diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py new file mode 100644 index 000000000..48bcc6892 --- /dev/null +++ b/comfy_extras/nodes_model_downscale.py @@ -0,0 +1,53 @@ +import torch +import comfy.utils + +class PatchModelAddDownscale: + upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}), + "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), + "downscale_after_skip": ("BOOLEAN", {"default": True}), + "downscale_method": (s.upscale_methods,), + "upscale_method": (s.upscale_methods,), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method): + sigma_start = model.model.model_sampling.percent_to_sigma(start_percent) + sigma_end = model.model.model_sampling.percent_to_sigma(end_percent) + + def input_block_patch(h, transformer_options): + if transformer_options["block"][1] == block_number: + sigma = transformer_options["sigmas"][0].item() + if sigma <= sigma_start and sigma >= sigma_end: + h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled") + return h + + def output_block_patch(h, hsp, transformer_options): + if h.shape[2] != hsp.shape[2]: + h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled") + return h, hsp + + m = model.clone() + if downscale_after_skip: + m.set_model_input_block_patch_after_skip(input_block_patch) + else: + m.set_model_input_block_patch(input_block_patch) + m.set_model_output_block_patch(output_block_patch) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "PatchModelAddDownscale": PatchModelAddDownscale, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + # Sampling + "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)", +} diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 3f651e594..12704f545 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -23,7 +23,7 @@ class Blend: "max": 1.0, "step": 0.01 }), - "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],), + "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],), }, } @@ -54,6 +54,8 @@ class Blend: return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) elif mode == "soft_light": return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) + elif mode == "difference": + return img1 - img2 else: raise ValueError(f"Unsupported blend mode: {mode}") @@ -126,7 +128,7 @@ class Quantize: "max": 256, "step": 1 }), - "dither": (["none", "floyd-steinberg"],), + "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],), }, } @@ -135,19 +137,47 @@ class Quantize: CATEGORY = "image/postprocessing" - def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"): + def bayer(im, pal_im, order): + def 
normalized_bayer_matrix(n): + if n == 0: + return np.zeros((1,1), "float32") + else: + q = 4 ** n + m = q * normalized_bayer_matrix(n - 1) + return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q + + num_colors = len(pal_im.getpalette()) // 3 + spread = 2 * 256 / num_colors + bayer_n = int(math.log2(order)) + bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5) + + result = torch.from_numpy(np.array(im).astype(np.float32)) + tw = math.ceil(result.shape[0] / bayer_matrix.shape[0]) + th = math.ceil(result.shape[1] / bayer_matrix.shape[1]) + tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1) + result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255) + result = result.to(dtype=torch.uint8) + + im = Image.fromarray(result.cpu().numpy()) + im = im.quantize(palette=pal_im, dither=Image.Dither.NONE) + return im + + def quantize(self, image: torch.Tensor, colors: int, dither: str): batch_size, height, width, _ = image.shape result = torch.zeros_like(image) - dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE - for b in range(batch_size): - tensor_image = image[b] - img = (tensor_image * 255).to(torch.uint8).numpy() - pil_image = Image.fromarray(img, mode='RGB') + im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB') - palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836 - quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option) + pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836 + + if dither == "none": + quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE) + elif dither == "floyd-steinberg": + quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG) + elif dither.startswith("bayer"): + order = int(dither.split('-')[-1]) + quantized_image = Quantize.bayer(im, pal_im, order) quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 result[b] = quantized_array diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index 0a9daf272..88a4ebe29 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -4,7 +4,7 @@ class LatentRebatch: @classmethod def INPUT_TYPES(s): return {"required": { "latents": ("LATENT",), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), }} RETURN_TYPES = ("LATENT",) INPUT_IS_LIST = True diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py new file mode 100644 index 000000000..26a717a38 --- /dev/null +++ b/comfy_extras/nodes_video_model.py @@ -0,0 +1,89 @@ +import nodes +import torch +import comfy.utils +import comfy.sd +import folder_paths + + +class ImageOnlyCheckpointLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + }} + RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") + FUNCTION = "load_checkpoint" + + CATEGORY = "loaders/video_models" + + def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return (out[0], out[3], out[2]) + + 
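Side note (not part of the patch): the index shuffle in `ImageOnlyCheckpointLoader.load_checkpoint` assumes `comfy.sd.load_checkpoint_guess_config` returns its outputs in `(model, clip, vae, clipvision)` order, so `out[0]`, `out[3]` and `out[2]` line up with the node's `MODEL`, `CLIP_VISION` and `VAE` return types. A minimal sketch under that assumption:

```python
# Illustrative sketch only, not part of the patch. Assumes ComfyUI is importable
# and that load_checkpoint_guess_config returns (model, clip, vae, clipvision).
import comfy.sd
import folder_paths

def load_image_only_checkpoint(ckpt_name):
    ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
    out = comfy.sd.load_checkpoint_guess_config(
        ckpt_path,
        output_vae=True,
        output_clip=False,        # the img2vid checkpoint carries no text encoder
        output_clipvision=True,
        embedding_directory=folder_paths.get_folder_paths("embeddings"),
    )
    model, _clip, vae, clip_vision = out
    return model, clip_vision, vae   # same ordering as (out[0], out[3], out[2])
```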
+class SVD_img2vid_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}), + "motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}), + "fps": ("INT", {"default": 6, "min": 1, "max": 1024}), + "augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}) + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + if augmentation_level > 0: + encode_pixels += torch.randn_like(pixels) * augmentation_level + t = vae.encode(encode_pixels) + positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]] + negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]] + latent = torch.zeros([video_frames, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + +class VideoLinearCFGGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "sampling/video_models" + + def patch(self, model, min_cfg): + def linear_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + + scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1)) + return uncond + scale * (cond - uncond) + + m = model.clone() + m.set_model_sampler_cfg_function(linear_cfg) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, + "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, + "VideoLinearCFGGuidance": VideoLinearCFGGuidance, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)", +} diff --git a/execution.py b/execution.py index 2ebae8ba0..f7316b015 100644 --- a/execution.py +++ b/execution.py @@ -683,6 +683,7 @@ def validate_prompt(prompt): return (True, None, list(good_outputs), node_errors) +MAXIMUM_HISTORY_SIZE = 10000 class PromptQueue: def __init__(self, server): @@ -715,6 +716,8 @@ class PromptQueue: def task_done(self, item_id, outputs): with self.mutex: prompt = self.currently_running.pop(item_id) + if len(self.history) > MAXIMUM_HISTORY_SIZE: + self.history.pop(next(iter(self.history))) self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } for o in outputs: self.history[prompt[1]]["outputs"][o] = outputs[o] @@ -749,10 +752,20 @@ class PromptQueue: return True return False - def get_history(self, prompt_id=None): + def get_history(self, 
prompt_id=None, max_items=None, offset=-1): with self.mutex: if prompt_id is None: - return copy.deepcopy(self.history) + out = {} + i = 0 + if offset < 0 and max_items is not None: + offset = len(self.history) - max_items + for k in self.history: + if i >= offset: + out[k] = self.history[k] + if max_items is not None and len(out) >= max_items: + break + i += 1 + return out elif prompt_id in self.history: return {prompt_id: copy.deepcopy(self.history[prompt_id])} else: diff --git a/folder_paths.py b/folder_paths.py index 64964bd87..482cc2222 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -38,7 +38,10 @@ input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "inp filename_list_cache = {} if not os.path.exists(input_directory): - os.makedirs(input_directory) + try: + os.makedirs(input_directory) + except: + print("Failed to create input directory") def set_output_directory(output_dir): global output_directory @@ -227,8 +230,12 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height full_output_folder = os.path.join(output_dir, subfolder) if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir: - print("Saving image outside the output folder is not allowed.") - return {} + err = "**** ERROR: Saving image outside the output folder is not allowed." + \ + "\n full_output_folder: " + os.path.abspath(full_output_folder) + \ + "\n output_dir: " + output_dir + \ + "\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) + print(err) + raise Exception(err) try: counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 diff --git a/latent_preview.py b/latent_preview.py index e1553c85c..61754751e 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer): self.taesd = taesd def decode_latent_to_preview(self, x0): - x_sample = self.taesd.decoder(x0)[0].detach() - # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2] - x_sample = x_sample.sub(0.5).mul(2) - + x_sample = self.taesd.decode(x0[:1])[0].detach() x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. 
* np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) diff --git a/main.py b/main.py index 1100a07f4..3997fbefc 100644 --- a/main.py +++ b/main.py @@ -88,6 +88,7 @@ def cuda_malloc_warning(): def prompt_worker(q, server): e = execution.PromptExecutor(server) + last_gc_collect = 0 while True: item, item_id = q.get() execution_start_time = time.perf_counter() @@ -97,9 +98,14 @@ def prompt_worker(q, server): if server.client_id is not None: server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) - print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) - gc.collect() - comfy.model_management.soft_empty_cache() + current_time = time.perf_counter() + execution_time = current_time - execution_start_time + print("Prompt executed in {:.2f} seconds".format(execution_time)) + if (current_time - last_gc_collect) > 10.0: + gc.collect() + comfy.model_management.soft_empty_cache() + last_gc_collect = current_time + print("gc collect") async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) diff --git a/nodes.py b/nodes.py index 97012413d..eb9ee3b34 100644 --- a/nodes.py +++ b/nodes.py @@ -248,8 +248,8 @@ class ConditioningSetTimestepRange: c = [] for t in conditioning: d = t[1].copy() - d['start_percent'] = 1.0 - start - d['end_percent'] = 1.0 - end + d['start_percent'] = start + d['end_percent'] = end n = [t[0], d] c.append(n) return (c, ) @@ -572,10 +572,69 @@ class LoraLoader: model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) return (model_lora, clip_lora) -class VAELoader: +class LoraLoaderModelOnly(LoraLoader): @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}} + return {"required": { "model": ("MODEL",), + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_lora_model_only" + + def load_lora_model_only(self, model, lora_name, strength_model): + return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) + +class VAELoader: + @staticmethod + def vae_list(): + vaes = folder_paths.get_filename_list("vae") + approx_vaes = folder_paths.get_filename_list("vae_approx") + sdxl_taesd_enc = False + sdxl_taesd_dec = False + sd1_taesd_enc = False + sd1_taesd_dec = False + + for v in approx_vaes: + if v.startswith("taesd_decoder."): + sd1_taesd_dec = True + elif v.startswith("taesd_encoder."): + sd1_taesd_enc = True + elif v.startswith("taesdxl_decoder."): + sdxl_taesd_dec = True + elif v.startswith("taesdxl_encoder."): + sdxl_taesd_enc = True + if sd1_taesd_dec and sd1_taesd_enc: + vaes.append("taesd") + if sdxl_taesd_dec and sdxl_taesd_enc: + vaes.append("taesdxl") + return vaes + + @staticmethod + def load_taesd(name): + sd = {} + approx_vaes = folder_paths.get_filename_list("vae_approx") + + encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes)) + decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes)) + + enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder)) + for k in enc: + sd["taesd_encoder.{}".format(k)] = enc[k] + + dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder)) + for k in dec: + 
sd["taesd_decoder.{}".format(k)] = dec[k] + + if name == "taesd": + sd["vae_scale"] = torch.tensor(0.18215) + elif name == "taesdxl": + sd["vae_scale"] = torch.tensor(0.13025) + return sd + + @classmethod + def INPUT_TYPES(s): + return {"required": { "vae_name": (s.vae_list(), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -583,8 +642,11 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): - vae_path = folder_paths.get_full_path("vae", vae_name) - sd = comfy.utils.load_torch_file(vae_path) + if vae_name in ["taesd", "taesdxl"]: + sd = self.load_taesd(vae_name) + else: + vae_path = folder_paths.get_full_path("vae", vae_name) + sd = comfy.utils.load_torch_file(vae_path) vae = comfy.sd.VAE(sd=sd) return (vae,) @@ -685,7 +747,7 @@ class ControlNetApplyAdvanced: if prev_cnet in cnets: c_net = cnets[prev_cnet] else: - c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent)) + c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent)) c_net.set_previous_controlnet(prev_cnet) cnets[prev_cnet] = c_net @@ -1218,7 +1280,7 @@ class KSampler: {"model": ("MODEL",), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), "positive": ("CONDITIONING", ), @@ -1244,7 +1306,7 @@ class KSamplerAdvanced: "add_noise": (["enable", "disable"], ), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), "positive": ("CONDITIONING", ), @@ -1275,6 +1337,7 @@ class SaveImage: self.output_dir = folder_paths.get_output_directory() self.type = "output" self.prefix_append = "" + self.compress_level = 4 @classmethod def INPUT_TYPES(s): @@ -1308,7 +1371,7 @@ class SaveImage: metadata.add_text(x, json.dumps(extra_pnginfo[x])) file = f"{filename}_{counter:05}_.png" - img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4) + img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level) results.append({ "filename": file, "subfolder": subfolder, @@ -1323,6 +1386,7 @@ class PreviewImage(SaveImage): self.output_dir = folder_paths.get_temp_directory() self.type = "temp" self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) + self.compress_level = 1 @classmethod def INPUT_TYPES(s): @@ -1654,6 +1718,7 @@ NODE_CLASS_MAPPINGS = { "ConditioningZeroOut": ConditioningZeroOut, "ConditioningSetTimestepRange": ConditioningSetTimestepRange, + "LoraLoaderModelOnly": LoraLoaderModelOnly, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -1759,7 +1824,7 @@ def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") node_import_times = [] for custom_node_path in node_paths: - possible_modules = os.listdir(custom_node_path) + possible_modules = os.listdir(os.path.realpath(custom_node_path)) if 
"__pycache__" in possible_modules: possible_modules.remove("__pycache__") @@ -1799,6 +1864,10 @@ def init_custom_nodes(): "nodes_custom_sampler.py", "nodes_subflow.py", "nodes_hypertile.py", + "nodes_model_advanced.py", + "nodes_model_downscale.py", + "nodes_images.py", + "nodes_video_model.py", ] for node_file in extras_files: diff --git a/server.py b/server.py index bde01e0dc..0e16eee27 100644 --- a/server.py +++ b/server.py @@ -83,7 +83,8 @@ class PromptServer(): if args.enable_cors_header: middlewares.append(create_cors_middleware(args.enable_cors_header)) - self.app = web.Application(client_max_size=104857600, middlewares=middlewares) + max_upload_size = round(args.max_upload_size * 1024 * 1024) + self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.sockets = dict() self.web_root = os.path.join(os.path.dirname( os.path.realpath(__file__)), "web") @@ -433,7 +434,10 @@ class PromptServer(): @routes.get("/history") async def get_history(request): - return web.json_response(self.prompt_queue.get_history()) + max_items = request.rel_url.query.get("max_items", None) + if max_items is not None: + max_items = int(max_items) + return web.json_response(self.prompt_queue.get_history(max_items=max_items)) @routes.get("/history/{prompt_id}") async def get_history(request): @@ -590,7 +594,7 @@ class PromptServer(): bytesIO = BytesIO() header = struct.pack(">I", type_num) bytesIO.write(header) - image.save(bytesIO, format=image_type, quality=95, compress_level=4) + image.save(bytesIO, format=image_type, quality=95, compress_level=1) preview_bytes = bytesIO.getvalue() await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) diff --git a/tests-ui/tests/widgetInputs.test.js b/tests-ui/tests/widgetInputs.test.js index 022e54926..e1873105a 100644 --- a/tests-ui/tests/widgetInputs.test.js +++ b/tests-ui/tests/widgetInputs.test.js @@ -14,10 +14,10 @@ const lg = require("../utils/litegraph"); * @param { InstanceType } graph * @param { InstanceType } input * @param { string } widgetType - * @param { boolean } hasControlWidget + * @param { number } controlWidgetCount * @returns */ -async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasControlWidget) { +async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) { // Connect to primitive and ensure its still connected after let primitive = ez.PrimitiveNode(); primitive.outputs[0].connectTo(input); @@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro expect(valueWidget.widget.type).toBe(widgetType); // Check if control_after_generate should be added - if (hasControlWidget) { + if (controlWidgetCount) { const controlWidget = primitive.widgets.control_after_generate; expect(controlWidget.widget.type).toBe("combo"); + if(widgetType === "combo") { + const filterWidget = primitive.widgets.control_filter_list; + expect(filterWidget.widget.type).toBe("string"); + } } // Ensure we dont have other widgets - expect(primitive.node.widgets).toHaveLength(1 + +!!hasControlWidget); + expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount); }); return primitive; @@ -55,8 +59,8 @@ describe("widget inputs", () => { }); [ - { name: "int", type: "INT", widget: "number", control: true }, - { name: "float", type: "FLOAT", widget: "number", control: true }, + { name: "int", type: "INT", widget: "number", control: 1 }, + { name: "float", type: "FLOAT", widget: "number", control: 1 }, { name: "text", type: "STRING" 
}, { name: "customtext", @@ -64,7 +68,7 @@ describe("widget inputs", () => { opt: { multiline: true }, }, { name: "toggle", type: "BOOLEAN" }, - { name: "combo", type: ["a", "b", "c"], control: true }, + { name: "combo", type: ["a", "b", "c"], control: 2 }, ].forEach((c) => { test(`widget conversion + primitive works on ${c.name}`, async () => { const { ez, graph } = await start({ @@ -106,7 +110,7 @@ describe("widget inputs", () => { n.widgets.ckpt_name.convertToInput(); expect(n.inputs.length).toEqual(inputCount + 1); - const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", true); + const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2); // Disconnect & reconnect primitive.outputs[0].connections[0].disconnect(); @@ -226,7 +230,7 @@ describe("widget inputs", () => { // Reload and ensure it still only has 1 converted widget if (!assertNotNullOrUndefined(input)) return; - await connectPrimitiveAndReload(ez, graph, input, "number", true); + await connectPrimitiveAndReload(ez, graph, input, "number", 1); n = graph.find(n); expect(n.widgets).toHaveLength(1); w = n.widgets.example; @@ -258,7 +262,7 @@ describe("widget inputs", () => { // Reload and ensure it still only has 1 converted widget if (assertNotNullOrUndefined(input)) { - await connectPrimitiveAndReload(ez, graph, input, "number", true); + await connectPrimitiveAndReload(ez, graph, input, "number", 1); n = graph.find(n); expect(n.widgets).toHaveLength(1); expect(n.widgets.example.isConvertedToInput).toBeTruthy(); @@ -316,4 +320,76 @@ describe("widget inputs", () => { n1.outputs[0].connectTo(n2.inputs[0]); expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow(); }); + + test("combo primitive can filter list when control_after_generate called", async () => { + const { ez } = await start({ + mockNodeDefs: { + ...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }), + }, + }); + + const n1 = ez.TestNode1(); + n1.widgets.example.convertToInput(); + const p = ez.PrimitiveNode() + p.outputs[0].connectTo(n1.inputs[0]); + + const value = p.widgets.value; + const control = p.widgets.control_after_generate.widget; + const filter = p.widgets.control_filter_list; + + expect(p.widgets.length).toBe(3); + control.value = "increment"; + expect(value.value).toBe("A"); + + // Manually trigger after queue when set to increment + control["afterQueued"](); + expect(value.value).toBe("B"); + + // Filter to items containing D + filter.value = "D"; + control["afterQueued"](); + expect(value.value).toBe("D"); + control["afterQueued"](); + expect(value.value).toBe("DD"); + + // Check decrement + value.value = "BBB"; + control.value = "decrement"; + filter.value = "B"; + control["afterQueued"](); + expect(value.value).toBe("BB"); + control["afterQueued"](); + expect(value.value).toBe("B"); + + // Check regex works + value.value = "BBB"; + filter.value = "/[AB]|^C$/"; + control["afterQueued"](); + expect(value.value).toBe("AAA"); + control["afterQueued"](); + expect(value.value).toBe("BB"); + control["afterQueued"](); + expect(value.value).toBe("AA"); + control["afterQueued"](); + expect(value.value).toBe("C"); + control["afterQueued"](); + expect(value.value).toBe("B"); + control["afterQueued"](); + expect(value.value).toBe("A"); + + // Check random + control.value = "randomize"; + filter.value = "/D/"; + for(let i = 0; i < 100; i++) { + control["afterQueued"](); + expect(value.value === "D" || value.value === "DD").toBeTruthy(); + } + 
+ // Ensure it doesnt apply when fixed + control.value = "fixed"; + value.value = "B"; + filter.value = "C"; + control["afterQueued"](); + expect(value.value).toBe("B"); + }); }); diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index 3695b08e2..b8d83613d 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -174,6 +174,213 @@ const colorPalettes = { "tr-odd-bg-color": "#073642", } }, + }, + "arc": { + "id": "arc", + "name": "Arc", + "colors": { + "node_slot": { + "BOOLEAN": "", + "CLIP": "#eacb8b", + "CLIP_VISION": "#A8DADC", + "CLIP_VISION_OUTPUT": "#ad7452", + "CONDITIONING": "#cf876f", + "CONTROL_NET": "#00d78d", + "CONTROL_NET_WEIGHTS": "", + "FLOAT": "", + "GLIGEN": "", + "IMAGE": "#80a1c0", + "IMAGEUPLOAD": "", + "INT": "", + "LATENT": "#b38ead", + "LATENT_KEYFRAME": "", + "MASK": "#a3bd8d", + "MODEL": "#8978a7", + "SAMPLER": "", + "SIGMAS": "", + "STRING": "", + "STYLE_MODEL": "#C2FFAE", + "T2I_ADAPTER_WEIGHTS": "", + "TAESD": "#DCC274", + "TIMESTEP_KEYFRAME": "", + "UPSCALE_MODEL": "", + "VAE": "#be616b" + }, + "litegraph_base": { + "BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAYAAABw4pVUAAAACXBIWXMAAAsTAAALEwEAmpwYAAABcklEQVR4nO3YMUoDARgF4RfxBqZI6/0vZqFn0MYtrLIQMFN8U6V4LAtD+Jm9XG/v30OGl2e/AP7yevz4+vx45nvgF/+QGITEICQGITEIiUFIjNNC3q43u3/YnRJyPOzeQ+0e220nhRzReC8e7R7bbdvl+Jal1Bs46jEIiUFIDEJiEBKDkBhKPbZT6qHdptRTu02p53DUYxASg5AYhMQgJAYhMZR6bKfUQ7tNqad2m1LP4ajHICQGITEIiUFIDEJiKPXYTqmHdptST+02pZ7DUY9BSAxCYhASg5AYhMRQ6rGdUg/tNqWe2m1KPYejHoOQGITEICQGITEIiaHUYzulHtptSj2125R6Dkc9BiExCIlBSAxCYhASQ6nHdko9tNuUemq3KfUcjnoMQmIQEoOQGITEICSGUo/tlHpotyn11G5T6jkc9RiExCAkBiExCIlBSAylHtsp9dBuU+qp3abUczjqMQiJQUgMQmIQEoOQGITE+AHFISNQrFTGuwAAAABJRU5ErkJggg==", + "CLEAR_BACKGROUND_COLOR": "#2b2f38", + "NODE_TITLE_COLOR": "#b2b7bd", + "NODE_SELECTED_TITLE_COLOR": "#FFF", + "NODE_TEXT_SIZE": 14, + "NODE_TEXT_COLOR": "#AAA", + "NODE_SUBTEXT_SIZE": 12, + "NODE_DEFAULT_COLOR": "#2b2f38", + "NODE_DEFAULT_BGCOLOR": "#242730", + "NODE_DEFAULT_BOXCOLOR": "#6e7581", + "NODE_DEFAULT_SHAPE": "box", + "NODE_BOX_OUTLINE_COLOR": "#FFF", + "DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)", + "DEFAULT_GROUP_FONT": 22, + "WIDGET_BGCOLOR": "#2b2f38", + "WIDGET_OUTLINE_COLOR": "#6e7581", + "WIDGET_TEXT_COLOR": "#DDD", + "WIDGET_SECONDARY_TEXT_COLOR": "#b2b7bd", + "LINK_COLOR": "#9A9", + "EVENT_LINK_COLOR": "#A86", + "CONNECTING_LINK_COLOR": "#AFA" + }, + "comfy_base": { + "fg-color": "#fff", + "bg-color": "#2b2f38", + "comfy-menu-bg": "#242730", + "comfy-input-bg": "#2b2f38", + "input-text": "#ddd", + "descrip-text": "#b2b7bd", + "drag-text": "#ccc", + "error-text": "#ff4444", + "border-color": "#6e7581", + "tr-even-bg-color": "#2b2f38", + "tr-odd-bg-color": "#242730" + } + }, + }, + "nord": { + "id": "nord", + "name": "Nord", + "colors": { + "node_slot": { + "BOOLEAN": "", + "CLIP": "#eacb8b", + "CLIP_VISION": "#A8DADC", + "CLIP_VISION_OUTPUT": "#ad7452", + "CONDITIONING": "#cf876f", + "CONTROL_NET": "#00d78d", + "CONTROL_NET_WEIGHTS": "", + "FLOAT": "", + "GLIGEN": "", + "IMAGE": "#80a1c0", + "IMAGEUPLOAD": "", + "INT": "", + "LATENT": "#b38ead", + "LATENT_KEYFRAME": "", + "MASK": "#a3bd8d", + "MODEL": "#8978a7", + "SAMPLER": "", + "SIGMAS": "", + "STRING": "", + "STYLE_MODEL": "#C2FFAE", + "T2I_ADAPTER_WEIGHTS": "", + "TAESD": "#DCC274", + "TIMESTEP_KEYFRAME": "", + "UPSCALE_MODEL": "", + "VAE": "#be616b" + }, + "litegraph_base": { + "BACKGROUND_IMAGE": 
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAACXBIWXMAAAsTAAALEwEAmpwYAAAFu2lUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4gPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iQWRvYmUgWE1QIENvcmUgOS4xLWMwMDEgNzkuMTQ2Mjg5OSwgMjAyMy8wNi8yNS0yMDowMTo1NSAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0RXZ0PSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VFdmVudCMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiB4bXA6Q3JlYXRlRGF0ZT0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgeG1wOk1vZGlmeURhdGU9IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIHhtcDpNZXRhZGF0YURhdGU9IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIGRjOmZvcm1hdD0iaW1hZ2UvcG5nIiBwaG90b3Nob3A6Q29sb3JNb2RlPSIzIiB4bXBNTTpJbnN0YW5jZUlEPSJ4bXAuaWlkOjUwNDFhMmZjLTEzNzQtMTk0ZC1hZWY4LTYxMzM1MTVmNjUwMCIgeG1wTU06RG9jdW1lbnRJRD0ieG1wLmRpZDoyMzFiMTBiMC1iNGZiLTAyNGUtYjEyZS0zMDUzMDNjZDA3YzgiIHhtcE1NOk9yaWdpbmFsRG9jdW1lbnRJRD0ieG1wLmRpZDoyMzFiMTBiMC1iNGZiLTAyNGUtYjEyZS0zMDUzMDNjZDA3YzgiPiA8eG1wTU06SGlzdG9yeT4gPHJkZjpTZXE+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJjcmVhdGVkIiBzdEV2dDppbnN0YW5jZUlEPSJ4bXAuaWlkOjIzMWIxMGIwLWI0ZmItMDI0ZS1iMTJlLTMwNTMwM2NkMDdjOCIgc3RFdnQ6d2hlbj0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgc3RFdnQ6c29mdHdhcmVBZ2VudD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIi8+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJzYXZlZCIgc3RFdnQ6aW5zdGFuY2VJRD0ieG1wLmlpZDo1MDQxYTJmYy0xMzc0LTE5NGQtYWVmOC02MTMzNTE1ZjY1MDAiIHN0RXZ0OndoZW49IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIHN0RXZ0OnNvZnR3YXJlQWdlbnQ9IkFkb2JlIFBob3Rvc2hvcCAyNS4xIChXaW5kb3dzKSIgc3RFdnQ6Y2hhbmdlZD0iLyIvPiA8L3JkZjpTZXE+IDwveG1wTU06SGlzdG9yeT4gPC9yZGY6RGVzY3JpcHRpb24+IDwvcmRmOlJERj4gPC94OnhtcG1ldGE+IDw/eHBhY2tldCBlbmQ9InIiPz73jWg/AAAAyUlEQVR42u3WKwoAIBRFQRdiMb1idv9Lsxn9gEFw4Dbb8JCTojbbXEJwjJVL2HKwYMGCBQuWLbDmjr+9zrBGjHl1WVcvy2DBggULFizTWQpewSt4HzwsgwULFiwFr7MUvMtS8D54WLBgGSxYCl7BK3iXZbBgwYIFC5bpLAWv4BW8Dx6WwYIFC5aC11kK3mUpeB88LFiwDBYsBa/gFbzLMliwYMGCBct0loJX8AreBw/LYMGCBUvB6ywF77IUvA8eFixYBgsWrNfWAZPltufdad+1AAAAAElFTkSuQmCC", + "CLEAR_BACKGROUND_COLOR": "#212732", + "NODE_TITLE_COLOR": "#999", + "NODE_SELECTED_TITLE_COLOR": "#e5eaf0", + "NODE_TEXT_SIZE": 14, + "NODE_TEXT_COLOR": "#bcc2c8", + "NODE_SUBTEXT_SIZE": 12, + "NODE_DEFAULT_COLOR": "#2e3440", + "NODE_DEFAULT_BGCOLOR": "#161b22", + "NODE_DEFAULT_BOXCOLOR": "#545d70", + "NODE_DEFAULT_SHAPE": "box", + "NODE_BOX_OUTLINE_COLOR": "#e5eaf0", + "DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)", + "DEFAULT_GROUP_FONT": 24, + "WIDGET_BGCOLOR": "#2e3440", + "WIDGET_OUTLINE_COLOR": "#545d70", + "WIDGET_TEXT_COLOR": "#bcc2c8", + "WIDGET_SECONDARY_TEXT_COLOR": "#999", + "LINK_COLOR": "#9A9", + "EVENT_LINK_COLOR": "#A86", + "CONNECTING_LINK_COLOR": "#AFA" + }, + "comfy_base": { + "fg-color": "#e5eaf0", + "bg-color": "#2e3440", + "comfy-menu-bg": "#161b22", + "comfy-input-bg": "#2e3440", + "input-text": "#bcc2c8", + "descrip-text": "#999", + "drag-text": "#ccc", + "error-text": "#ff4444", + "border-color": "#545d70", + "tr-even-bg-color": "#2e3440", + "tr-odd-bg-color": "#161b22" + } + }, + }, + "github": { + "id": "github", + "name": "Github", + "colors": { + "node_slot": { + "BOOLEAN": "", + "CLIP": "#eacb8b", + "CLIP_VISION": "#A8DADC", + 
"CLIP_VISION_OUTPUT": "#ad7452", + "CONDITIONING": "#cf876f", + "CONTROL_NET": "#00d78d", + "CONTROL_NET_WEIGHTS": "", + "FLOAT": "", + "GLIGEN": "", + "IMAGE": "#80a1c0", + "IMAGEUPLOAD": "", + "INT": "", + "LATENT": "#b38ead", + "LATENT_KEYFRAME": "", + "MASK": "#a3bd8d", + "MODEL": "#8978a7", + "SAMPLER": "", + "SIGMAS": "", + "STRING": "", + "STYLE_MODEL": "#C2FFAE", + "T2I_ADAPTER_WEIGHTS": "", + "TAESD": "#DCC274", + "TIMESTEP_KEYFRAME": "", + "UPSCALE_MODEL": "", + "VAE": "#be616b" + }, + "litegraph_base": { + "BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGlmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4gPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iQWRvYmUgWE1QIENvcmUgOS4xLWMwMDEgNzkuMTQ2Mjg5OSwgMjAyMy8wNi8yNS0yMDowMTo1NSAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0RXZ0PSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VFdmVudCMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiB4bXA6Q3JlYXRlRGF0ZT0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgeG1wOk1vZGlmeURhdGU9IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIHhtcDpNZXRhZGF0YURhdGU9IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIGRjOmZvcm1hdD0iaW1hZ2UvcG5nIiBwaG90b3Nob3A6Q29sb3JNb2RlPSIzIiB4bXBNTTpJbnN0YW5jZUlEPSJ4bXAuaWlkOmIyYzRhNjA5LWJmYTctYTg0MC1iOGFlLTk3MzE2ZjM1ZGIyNyIgeG1wTU06RG9jdW1lbnRJRD0iYWRvYmU6ZG9jaWQ6cGhvdG9zaG9wOjk0ZmNlZGU4LTE1MTctZmQ0MC04ZGU3LWYzOTgxM2E3ODk5ZiIgeG1wTU06T3JpZ2luYWxEb2N1bWVudElEPSJ4bXAuZGlkOjIzMWIxMGIwLWI0ZmItMDI0ZS1iMTJlLTMwNTMwM2NkMDdjOCI+IDx4bXBNTTpIaXN0b3J5PiA8cmRmOlNlcT4gPHJkZjpsaSBzdEV2dDphY3Rpb249ImNyZWF0ZWQiIHN0RXZ0Omluc3RhbmNlSUQ9InhtcC5paWQ6MjMxYjEwYjAtYjRmYi0wMjRlLWIxMmUtMzA1MzAzY2QwN2M4IiBzdEV2dDp3aGVuPSIyMDIzLTExLTEzVDAwOjE4OjAyKzAxOjAwIiBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZG9iZSBQaG90b3Nob3AgMjUuMSAoV2luZG93cykiLz4gPHJkZjpsaSBzdEV2dDphY3Rpb249InNhdmVkIiBzdEV2dDppbnN0YW5jZUlEPSJ4bXAuaWlkOjQ4OWY1NzlmLTJkNjUtZWQ0Zi04OTg0LTA4NGE2MGE1ZTMzNSIgc3RFdnQ6d2hlbj0iMjAyMy0xMS0xNVQwMjowNDo1OSswMTowMCIgc3RFdnQ6c29mdHdhcmVBZ2VudD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiBzdEV2dDpjaGFuZ2VkPSIvIi8+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJzYXZlZCIgc3RFdnQ6aW5zdGFuY2VJRD0ieG1wLmlpZDpiMmM0YTYwOS1iZmE3LWE4NDAtYjhhZS05NzMxNmYzNWRiMjciIHN0RXZ0OndoZW49IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIHN0RXZ0OnNvZnR3YXJlQWdlbnQ9IkFkb2JlIFBob3Rvc2hvcCAyNS4xIChXaW5kb3dzKSIgc3RFdnQ6Y2hhbmdlZD0iLyIvPiA8L3JkZjpTZXE+IDwveG1wTU06SGlzdG9yeT4gPC9yZGY6RGVzY3JpcHRpb24+IDwvcmRmOlJERj4gPC94OnhtcG1ldGE+IDw/eHBhY2tldCBlbmQ9InIiPz4OTe6GAAAAx0lEQVR42u3WMQoAIQxFwRzJys77X8vSLiRgITif7bYbgrwYc/mKXyBoY4VVBgsWLFiwYFmOlTv+9jfDOjHmr8u6eVkGCxYsWLBgmc5S8ApewXvgYRksWLBgKXidpeBdloL3wMOCBctgwVLwCl7BuyyDBQsWLFiwTGcpeAWv4D3wsAwWLFiwFLzOUvAuS8F74GHBgmWwYCl4Ba/gXZbBggULFixYprMUvIJX8B54WAYLFixYCl5nKXiXpeA98LBgwTJYsGC9tg1o8f4TTtqzNQAAAABJRU5ErkJggg==", + "CLEAR_BACKGROUND_COLOR": "#040506", + "NODE_TITLE_COLOR": "#999", + "NODE_SELECTED_TITLE_COLOR": "#e5eaf0", + "NODE_TEXT_SIZE": 14, + "NODE_TEXT_COLOR": "#bcc2c8", + "NODE_SUBTEXT_SIZE": 12, + "NODE_DEFAULT_COLOR": "#161b22", + "NODE_DEFAULT_BGCOLOR": "#13171d", + "NODE_DEFAULT_BOXCOLOR": "#30363d", + "NODE_DEFAULT_SHAPE": "box", + 
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0", + "DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)", + "DEFAULT_GROUP_FONT": 24, + "WIDGET_BGCOLOR": "#161b22", + "WIDGET_OUTLINE_COLOR": "#30363d", + "WIDGET_TEXT_COLOR": "#bcc2c8", + "WIDGET_SECONDARY_TEXT_COLOR": "#999", + "LINK_COLOR": "#9A9", + "EVENT_LINK_COLOR": "#A86", + "CONNECTING_LINK_COLOR": "#AFA" + }, + "comfy_base": { + "fg-color": "#e5eaf0", + "bg-color": "#161b22", + "comfy-menu-bg": "#13171d", + "comfy-input-bg": "#161b22", + "input-text": "#bcc2c8", + "descrip-text": "#999", + "drag-text": "#ccc", + "error-text": "#ff4444", + "border-color": "#30363d", + "tr-even-bg-color": "#161b22", + "tr-odd-bg-color": "#13171d" + } + }, } }; diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js index 152cd7043..0a305391a 100644 --- a/web/extensions/core/contextMenuFilter.js +++ b/web/extensions/core/contextMenuFilter.js @@ -25,7 +25,7 @@ const ext = { requestAnimationFrame(() => { const currentNode = LGraphCanvas.active_canvas.current_node; const clickedComboValue = currentNode.widgets - .filter(w => w.type === "combo" && w.options.values.length === values.length) + ?.filter(w => w.type === "combo" && w.options.values.length === values.length) .find(w => w.options.values.every((v, i) => v === values[i])) ?.value; diff --git a/web/extensions/core/nodeTemplates.js b/web/extensions/core/nodeTemplates.js index 434491075..b6479f454 100644 --- a/web/extensions/core/nodeTemplates.js +++ b/web/extensions/core/nodeTemplates.js @@ -14,6 +14,9 @@ import { ComfyDialog, $el } from "../../scripts/ui.js"; // To delete/rename: // Right click the canvas // Node templates -> Manage +// +// To rearrange: +// Open the manage dialog and Drag and drop elements using the "Name:" label as handle const id = "Comfy.NodeTemplates"; @@ -22,6 +25,10 @@ class ManageTemplates extends ComfyDialog { super(); this.element.classList.add("comfy-manage-templates"); this.templates = this.load(); + this.draggedEl = null; + this.saveVisualCue = null; + this.emptyImg = new Image(); + this.emptyImg.src = 'data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs='; this.importInput = $el("input", { type: "file", @@ -35,14 +42,11 @@ class ManageTemplates extends ComfyDialog { createButtons() { const btns = super.createButtons(); - btns[0].textContent = "Cancel"; - btns.unshift( - $el("button", { - type: "button", - textContent: "Save", - onclick: () => this.save(), - }) - ); + btns[0].textContent = "Close"; + btns[0].onclick = (e) => { + clearTimeout(this.saveVisualCue); + this.close(); + }; btns.unshift( $el("button", { type: "button", @@ -71,25 +75,6 @@ class ManageTemplates extends ComfyDialog { } } - save() { - // Find all visible inputs and save them as our new list - const inputs = this.element.querySelectorAll("input"); - const updated = []; - - for (let i = 0; i < inputs.length; i++) { - const input = inputs[i]; - if (input.parentElement.style.display !== "none") { - const t = this.templates[i]; - t.name = input.value.trim() || input.getAttribute("data-name"); - updated.push(t); - } - } - - this.templates = updated; - this.store(); - this.close(); - } - store() { localStorage.setItem(id, JSON.stringify(this.templates)); } @@ -145,71 +130,155 @@ class ManageTemplates extends ComfyDialog { super.show( $el( "div", - { - style: { - display: "grid", - gridTemplateColumns: "1fr auto", - gap: "5px", - }, - }, - this.templates.flatMap((t) => { + {}, + this.templates.flatMap((t,i) => { let nameInput; return [ $el( - "label", + "div", { - 
textContent: "Name: ", + dataset: { id: i }, + className: "tempateManagerRow", + style: { + display: "grid", + gridTemplateColumns: "1fr auto", + border: "1px dashed transparent", + gap: "5px", + backgroundColor: "var(--comfy-menu-bg)" + }, + ondragstart: (e) => { + this.draggedEl = e.currentTarget; + e.currentTarget.style.opacity = "0.6"; + e.currentTarget.style.border = "1px dashed yellow"; + e.dataTransfer.effectAllowed = 'move'; + e.dataTransfer.setDragImage(this.emptyImg, 0, 0); + }, + ondragend: (e) => { + e.target.style.opacity = "1"; + e.currentTarget.style.border = "1px dashed transparent"; + e.currentTarget.removeAttribute("draggable"); + + // rearrange the elements in the localStorage + this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => { + var prev_i = el.dataset.id; + + if ( el == this.draggedEl && prev_i != i ) { + [this.templates[i], this.templates[prev_i]] = [this.templates[prev_i], this.templates[i]]; + } + el.dataset.id = i; + }); + this.store(); + }, + ondragover: (e) => { + e.preventDefault(); + if ( e.currentTarget == this.draggedEl ) + return; + + let rect = e.currentTarget.getBoundingClientRect(); + if (e.clientY > rect.top + rect.height / 2) { + e.currentTarget.parentNode.insertBefore(this.draggedEl, e.currentTarget.nextSibling); + } else { + e.currentTarget.parentNode.insertBefore(this.draggedEl, e.currentTarget); + } + } }, [ - $el("input", { - value: t.name, - dataset: { name: t.name }, - $: (el) => (nameInput = el), - }), + $el( + "label", + { + textContent: "Name: ", + style: { + cursor: "grab", + }, + onmousedown: (e) => { + // enable dragging only from the label + if (e.target.localName == 'label') + e.currentTarget.parentNode.draggable = 'true'; + } + }, + [ + $el("input", { + value: t.name, + dataset: { name: t.name }, + style: { + transitionProperty: 'background-color', + transitionDuration: '0s', + }, + onchange: (e) => { + clearTimeout(this.saveVisualCue); + var el = e.target; + var row = el.parentNode.parentNode; + this.templates[row.dataset.id].name = el.value.trim() || 'untitled'; + this.store(); + el.style.backgroundColor = 'rgb(40, 95, 40)'; + el.style.transitionDuration = '0s'; + this.saveVisualCue = setTimeout(function () { + el.style.transitionDuration = '.7s'; + el.style.backgroundColor = 'var(--comfy-input-bg)'; + }, 15); + }, + onkeypress: (e) => { + var el = e.target; + clearTimeout(this.saveVisualCue); + el.style.transitionDuration = '0s'; + el.style.backgroundColor = 'var(--comfy-input-bg)'; + }, + $: (el) => (nameInput = el), + }) + ] + ), + $el( + "div", + {}, + [ + $el("button", { + textContent: "Export", + style: { + fontSize: "12px", + fontWeight: "normal", + }, + onclick: (e) => { + const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string + const blob = new Blob([json], {type: "application/json"}); + const url = URL.createObjectURL(blob); + const a = $el("a", { + href: url, + download: (nameInput.value || t.name) + ".json", + style: {display: "none"}, + parent: document.body, + }); + a.click(); + setTimeout(function () { + a.remove(); + window.URL.revokeObjectURL(url); + }, 0); + }, + }), + $el("button", { + textContent: "Delete", + style: { + fontSize: "12px", + color: "red", + fontWeight: "normal", + }, + onclick: (e) => { + const item = e.target.parentNode.parentNode; + item.parentNode.removeChild(item); + this.templates.splice(item.dataset.id*1, 1); + this.store(); + // update the rows index, setTimeout ensures that the list is updated + var that = this; + setTimeout(function 
(){ + that.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => { + el.dataset.id = i; + }); + }, 0); + }, + }), + ] + ), ] - ), - $el( - "div", - {}, - [ - $el("button", { - textContent: "Export", - style: { - fontSize: "12px", - fontWeight: "normal", - }, - onclick: (e) => { - const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string - const blob = new Blob([json], {type: "application/json"}); - const url = URL.createObjectURL(blob); - const a = $el("a", { - href: url, - download: (nameInput.value || t.name) + ".json", - style: {display: "none"}, - parent: document.body, - }); - a.click(); - setTimeout(function () { - a.remove(); - window.URL.revokeObjectURL(url); - }, 0); - }, - }), - $el("button", { - textContent: "Delete", - style: { - fontSize: "12px", - color: "red", - fontWeight: "normal", - }, - onclick: (e) => { - nameInput.value = ""; - e.target.parentElement.style.display = "none"; - e.target.parentElement.previousElementSibling.style.display = "none"; - }, - }), - ] - ), + ) ]; }) ) diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js new file mode 100644 index 000000000..c6613b0f0 --- /dev/null +++ b/web/extensions/core/undoRedo.js @@ -0,0 +1,150 @@ +import { app } from "../../scripts/app.js"; + +const MAX_HISTORY = 50; + +let undo = []; +let redo = []; +let activeState = null; +let isOurLoad = false; +function checkState() { + const currentState = app.graph.serialize(); + if (!graphEqual(activeState, currentState)) { + undo.push(activeState); + if (undo.length > MAX_HISTORY) { + undo.shift(); + } + activeState = clone(currentState); + redo.length = 0; + } +} + +const loadGraphData = app.loadGraphData; +app.loadGraphData = async function () { + const v = await loadGraphData.apply(this, arguments); + if (isOurLoad) { + isOurLoad = false; + } else { + checkState(); + } + return v; +}; + +function clone(obj) { + try { + if (typeof structuredClone !== "undefined") { + return structuredClone(obj); + } + } catch (error) { + // structuredClone is stricter than using JSON.parse/stringify so fallback to that + } + + return JSON.parse(JSON.stringify(obj)); +} + +function graphEqual(a, b, root = true) { + if (a === b) return true; + + if (typeof a == "object" && a && typeof b == "object" && b) { + const keys = Object.getOwnPropertyNames(a); + + if (keys.length != Object.getOwnPropertyNames(b).length) { + return false; + } + + for (const key of keys) { + let av = a[key]; + let bv = b[key]; + if (root && key === "nodes") { + // Nodes need to be sorted as the order changes when selecting nodes + av = [...av].sort((a, b) => a.id - b.id); + bv = [...bv].sort((a, b) => a.id - b.id); + } + if (!graphEqual(av, bv, false)) { + return false; + } + } + + return true; + } + + return false; +} + +const undoRedo = async (e) => { + if (e.ctrlKey || e.metaKey) { + if (e.key === "y") { + const prevState = redo.pop(); + if (prevState) { + undo.push(activeState); + isOurLoad = true; + await app.loadGraphData(prevState); + activeState = prevState; + } + return true; + } else if (e.key === "z") { + const prevState = undo.pop(); + if (prevState) { + redo.push(activeState); + isOurLoad = true; + await app.loadGraphData(prevState); + activeState = prevState; + } + return true; + } + } +}; + +const bindInput = (activeEl) => { + if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") { + for (const evt of ["change", "input", "blur"]) { + if (`on${evt}` in activeEl) { + const listener = () => { + checkState(); + 
activeEl.removeEventListener(evt, listener); + }; + activeEl.addEventListener(evt, listener); + return true; + } + } + } +}; + +window.addEventListener( + "keydown", + (e) => { + requestAnimationFrame(async () => { + const activeEl = document.activeElement; + if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") { + // Ignore events on inputs, they have their native history + return; + } + + // Check if this is a ctrl+z ctrl+y + if (await undoRedo(e)) return; + + // If our active element is some type of input then handle changes after they're done + if (bindInput(activeEl)) return; + checkState(); + }); + }, + true +); + +// Handle clicking DOM elements (e.g. widgets) +window.addEventListener("mouseup", () => { + checkState(); +}); + +// Handle litegraph clicks +const processMouseUp = LGraphCanvas.prototype.processMouseUp; +LGraphCanvas.prototype.processMouseUp = function (e) { + const v = processMouseUp.apply(this, arguments); + checkState(); + return v; +}; +const processMouseDown = LGraphCanvas.prototype.processMouseDown; +LGraphCanvas.prototype.processMouseDown = function (e) { + const v = processMouseDown.apply(this, arguments); + checkState(); + return v; +}; diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index c370f4f3c..53e3a7a28 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -1,4 +1,4 @@ -import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js"; +import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js"; import { app } from "../../scripts/app.js"; const CONVERTED_TYPE = "converted-widget"; @@ -467,7 +467,11 @@ app.registerExtension({ if (!control_value) { control_value = "fixed"; } - addValueControlWidget(this, widget, control_value); + addValueControlWidgets(this, widget, control_value); + let filter = this.widgets_values?.[2]; + if(filter && this.widgets.length === 3) { + this.widgets[2].value = filter; + } } // When our value changes, update other widgets to reflect our changes diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index a56bf4db9..447b13c8d 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -2533,7 +2533,7 @@ var w = this.widgets[i]; if(!w) continue; - if(w.options && w.options.property && this.properties[ w.options.property ]) + if(w.options && w.options.property && (this.properties[ w.options.property ] != undefined)) w.value = JSON.parse( JSON.stringify( this.properties[ w.options.property ] ) ); } if (info.widgets_values) { @@ -5717,10 +5717,10 @@ LGraphNode.prototype.executeAction = function(action) * @method enableWebGL **/ LGraphCanvas.prototype.enableWebGL = function() { - if (typeof GL === undefined) { + if (typeof GL === "undefined") { throw "litegl.js must be included to use a WebGL canvas"; } - if (typeof enableWebGLCanvas === undefined) { + if (typeof enableWebGLCanvas === "undefined") { throw "webglCanvas.js must be included to use this feature"; } @@ -7113,15 +7113,16 @@ LGraphNode.prototype.executeAction = function(action) } }; - LGraphCanvas.prototype.copyToClipboard = function() { + LGraphCanvas.prototype.copyToClipboard = function(nodes) { var clipboard_info = { nodes: [], links: [] }; var index = 0; var selected_nodes_array = []; - for (var i in this.selected_nodes) { - var node = this.selected_nodes[i]; + if (!nodes) nodes = this.selected_nodes; + for (var i in nodes) { + var node = nodes[i]; if (node.clonable === false) continue; node._relative_id = index; @@ 
-11705,7 +11706,7 @@ LGraphNode.prototype.executeAction = function(action) default: iS = 0; // try with first if no name set } - if (typeof options.node_from.outputs[iS] !== undefined){ + if (typeof options.node_from.outputs[iS] !== "undefined"){ if (iS!==false && iS>-1){ options.node_from.connectByType( iS, node, options.node_from.outputs[iS].type ); } @@ -11733,7 +11734,7 @@ LGraphNode.prototype.executeAction = function(action) default: iS = 0; // try with first if no name set } - if (typeof options.node_to.inputs[iS] !== undefined){ + if (typeof options.node_to.inputs[iS] !== "undefined"){ if (iS!==false && iS>-1){ // try connection options.node_to.connectByTypeOutput(iS,node,options.node_to.inputs[iS].type); diff --git a/web/scripts/api.js b/web/scripts/api.js index 188d7a28d..5e2ce76b1 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -254,9 +254,9 @@ class ComfyApi extends EventTarget { * Gets the prompt execution history * @returns Prompt history including node outputs */ - async getHistory() { + async getHistory(max_items=200) { try { - const res = await this.fetchApi("/history"); + const res = await this.fetchApi(`/history?max_items=${max_items}`); return { History: Object.values(await res.json()) }; } catch (error) { console.error(error); diff --git a/web/scripts/app.js b/web/scripts/app.js index 6da924c4f..d454e88f3 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -3,7 +3,26 @@ import { ComfyWidgets } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; -import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; +import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; +import { addDomClippingSetting } from "./domWidget.js"; +import { createImageHost, calculateImageGrid } from "./ui/imagePreview.js" + +export const ANIM_PREVIEW_WIDGET = "$$comfy_animation_preview" + +function sanitizeNodeName(string) { + let entityMap = { + '&': '', + '<': '', + '>': '', + '"': '', + "'": '', + '`': '', + '=': '' + }; + return String(string).replace(/[&<>"'`=]/g, function fromEntityMap (s) { + return entityMap[s]; + }); +} /** * @typedef {import("types/comfy").ComfyExtension} ComfyExtension @@ -389,7 +408,9 @@ export class ComfyApp { return shiftY; } - node.prototype.setSizeForImage = function () { + node.prototype.setSizeForImage = function (force) { + if(!force && this.animatedImages) return; + if (this.inputHeight) { this.setSize(this.size); return; @@ -406,13 +427,20 @@ export class ComfyApp { let imagesChanged = false const output = app.nodeOutputs[this.id + ""]; - if (output && output.images) { + if (output?.images) { + this.animatedImages = output?.animated?.find(Boolean); if (this.images !== output.images) { this.images = output.images; imagesChanged = true; - imgURLs = imgURLs.concat(output.images.map(params => { - return api.apiURL("/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam()); - })) + imgURLs = imgURLs.concat( + output.images.map((params) => { + return api.apiURL( + "/view?" + + new URLSearchParams(params).toString() + + (this.animatedImages ? 
"" : app.getPreviewFormatParam()) + ); + }) + ); } } @@ -491,7 +519,35 @@ export class ComfyApp { return true; } - if (this.imgs && this.imgs.length) { + if (this.imgs?.length) { + const widgetIdx = this.widgets?.findIndex((w) => w.name === ANIM_PREVIEW_WIDGET); + + if(this.animatedImages) { + // Instead of using the canvas we'll use a IMG + if(widgetIdx > -1) { + // Replace content + const widget = this.widgets[widgetIdx]; + widget.options.host.updateImages(this.imgs); + } else { + const host = createImageHost(this); + this.setSizeForImage(true); + const widget = this.addDOMWidget(ANIM_PREVIEW_WIDGET, "img", host.el, { + host, + getHeight: host.getHeight, + onDraw: host.onDraw, + hideOnZoom: false + }); + widget.serializeValue = () => undefined; + widget.options.host.updateImages(this.imgs); + } + return; + } + + if (widgetIdx > -1) { + this.widgets[widgetIdx].onRemove?.(); + this.widgets.splice(widgetIdx, 1); + } + const canvas = app.graph.list_of_graphcanvas[0]; const mouse = canvas.graph_mouse; if (!canvas.pointer_is_down && this.pointerDown) { @@ -531,31 +587,7 @@ export class ComfyApp { } else { cell_padding = 0; - let best = 0; - let w = this.imgs[0].naturalWidth; - let h = this.imgs[0].naturalHeight; - - // compact style - for (let c = 1; c <= numImages; c++) { - const rows = Math.ceil(numImages / c); - const cW = dw / c; - const cH = dh / rows; - const scaleX = cW / w; - const scaleY = cH / h; - - const scale = Math.min(scaleX, scaleY, 1); - const imageW = w * scale; - const imageH = h * scale; - const area = imageW * imageH * numImages; - - if (area > best) { - best = area; - cellWidth = imageW; - cellHeight = imageH; - cols = c; - shiftX = c * ((cW - imageW) / 2); - } - } + ({ cellWidth, cellHeight, cols, shiftX } = calculateImageGrid(this.imgs, dw, dh)); } let anyHovered = false; @@ -1256,6 +1288,7 @@ export class ComfyApp { canvasEl.tabIndex = "1"; document.body.prepend(canvasEl); + addDomClippingSetting(); this.#addProcessMouseHandler(); this.#addProcessKeyHandler(); this.#addConfigureHandler(); @@ -1453,6 +1486,17 @@ export class ComfyApp { localStorage.setItem("litegrapheditor_clipboard", old); } + showMissingNodesError(missingNodeTypes, hasAddedNodes = true) { + this.ui.dialog.show( + `When loading the graph, the following node types were not found: ${hasAddedNodes ? "Nodes that have failed to load will show as red on the graph." 
: ""}` + ); + this.logging.addEntry("Comfy.App", "warn", { + MissingNodes: missingNodeTypes, + }); + } + /** * Populates the graph with the specified workflow data * @param {*} graphData A serialized graph object @@ -1462,24 +1506,28 @@ export class ComfyApp { let reset_invalid_values = false; if (!graphData) { - if (typeof structuredClone === "undefined") - { - graphData = JSON.parse(JSON.stringify(defaultGraph)); - }else - { - graphData = structuredClone(defaultGraph); - } + graphData = defaultGraph; reset_invalid_values = true; } + if (typeof structuredClone === "undefined") + { + graphData = JSON.parse(JSON.stringify(graphData)); + }else + { + graphData = structuredClone(graphData); + } + const missingNodeTypes = []; for (let n of graphData.nodes) { // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader"; if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix + if (n.type == "SDV_img2vid_Conditioning") n.type = "SVD_img2vid_Conditioning"; //typo fix // Find missing node types if (!(n.type in LiteGraph.registered_node_types)) { + n.type = sanitizeNodeName(n.type); missingNodeTypes.push(n.type); } } @@ -1570,14 +1618,7 @@ export class ComfyApp { } if (missingNodeTypes.length) { - this.ui.dialog.show( - `When loading the graph, the following node types were not found: Nodes that have failed to load will show as red on the graph.` - ); - this.logging.addEntry("Comfy.App", "warn", { - MissingNodes: missingNodeTypes, - }); + this.showMissingNodesError(missingNodeTypes); } } @@ -1589,7 +1630,17 @@ export class ComfyApp { * @returns The workflow and node links */ async graphToPrompt(graph=this.graph, nodeIdOffset=0, path="/", globalMappings={}) { - const workflow = graph.serialize(); + for (const node of this.graph.computeExecutionOrder(false)) { + if (node.isVirtualNode) { + // Don't serialize frontend only nodes but let them make changes + if (node.applyToGraph) { + node.applyToGraph(); + } + continue; + } + } + + const workflow = this.graph.serialize(); const output = {}; let childNodeCount = 0; @@ -1598,10 +1649,6 @@ export class ComfyApp { const n = workflow.nodes.find((n) => n.id === node.id); if (node.isVirtualNode) { - // Don't serialize frontend only nodes but let them make changes - if (node.applyToGraph) { - node.applyToGraph(workflow); - } continue; } @@ -1862,12 +1909,23 @@ export class ComfyApp { importA1111(this.graph, pngInfo.parameters); } } + } else if (file.type === "image/webp") { + const pngInfo = await getWebpMetadata(file); + if (pngInfo) { + if (pngInfo.workflow) { + this.loadGraphData(JSON.parse(pngInfo.workflow)); + } else if (pngInfo.Workflow) { + this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node. 
+ } + } } else if (file.type === "application/json" || file.name?.endsWith(".json")) { const reader = new FileReader(); reader.onload = () => { - var jsonContent = JSON.parse(reader.result); + const jsonContent = JSON.parse(reader.result); if (jsonContent?.templates) { this.loadTemplateData(jsonContent); + } else if(this.isApiJson(jsonContent)) { + this.loadApiJson(jsonContent); } else { this.loadGraphData(jsonContent); } @@ -1881,6 +1939,51 @@ export class ComfyApp { } } + isApiJson(data) { + return Object.values(data).every((v) => v.class_type); + } + + loadApiJson(apiData) { + const missingNodeTypes = Object.values(apiData).filter((n) => !LiteGraph.registered_node_types[n.class_type]); + if (missingNodeTypes.length) { + this.showMissingNodesError(missingNodeTypes.map(t => t.class_type), false); + return; + } + + const ids = Object.keys(apiData); + app.graph.clear(); + for (const id of ids) { + const data = apiData[id]; + const node = LiteGraph.createNode(data.class_type); + node.id = isNaN(+id) ? id : +id; + graph.add(node); + } + + for (const id of ids) { + const data = apiData[id]; + const node = app.graph.getNodeById(id); + for (const input in data.inputs ?? {}) { + const value = data.inputs[input]; + if (value instanceof Array) { + const [fromId, fromSlot] = value; + const fromNode = app.graph.getNodeById(fromId); + const toSlot = node.inputs?.findIndex((inp) => inp.name === input); + if (toSlot !== -1) { + fromNode.connect(fromSlot, node, toSlot); + } + } else { + const widget = node.widgets?.find((w) => w.name === input); + if (widget) { + widget.value = value; + widget.callback?.(value); + } + } + } + } + + app.graph.arrange(); + } + /** * Registers a Comfy web extension with the app * @param {ComfyExtension} extension diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js new file mode 100644 index 000000000..07da591cb --- /dev/null +++ b/web/scripts/domWidget.js @@ -0,0 +1,323 @@ +import { app, ANIM_PREVIEW_WIDGET } from "./app.js"; + +const SIZE = Symbol(); + +function intersect(a, b) { + const x = Math.max(a.x, b.x); + const num1 = Math.min(a.x + a.width, b.x + b.width); + const y = Math.max(a.y, b.y); + const num2 = Math.min(a.y + a.height, b.y + b.height); + if (num1 >= x && num2 >= y) return [x, y, num1 - x, num2 - y]; + else return null; +} + +function getClipPath(node, element, elRect) { + const selectedNode = Object.values(app.canvas.selected_nodes)[0]; + if (selectedNode && selectedNode !== node) { + const MARGIN = 7; + const scale = app.canvas.ds.scale; + + const bounding = selectedNode.getBounding(); + const intersection = intersect( + { x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale }, + { + x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN, + y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN, + width: bounding[2] + MARGIN + MARGIN, + height: bounding[3] + MARGIN + MARGIN, + } + ); + + if (!intersection) { + return ""; + } + + const widgetRect = element.getBoundingClientRect(); + const clipX = intersection[0] - widgetRect.x / scale + "px"; + const clipY = intersection[1] - widgetRect.y / scale + "px"; + const clipWidth = intersection[2] + "px"; + const clipHeight = intersection[3] + "px"; + const path = `polygon(0% 0%, 0% 100%, ${clipX} 100%, ${clipX} ${clipY}, calc(${clipX} + ${clipWidth}) ${clipY}, calc(${clipX} + ${clipWidth}) calc(${clipY} + ${clipHeight}), ${clipX} calc(${clipY} + ${clipHeight}), ${clipX} 100%, 100% 100%, 100% 0%)`; + return path; + } + 
return ""; +} + +function computeSize(size) { + if (this.widgets?.[0].last_y == null) return; + + let y = this.widgets[0].last_y; + let freeSpace = size[1] - y; + + let widgetHeight = 0; + let dom = []; + for (const w of this.widgets) { + if (w.type === "converted-widget") { + // Ignore + delete w.computedHeight; + } else if (w.computeSize) { + widgetHeight += w.computeSize()[1] + 4; + } else if (w.element) { + // Extract DOM widget size info + const styles = getComputedStyle(w.element); + let minHeight = w.options.getMinHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-min-height")); + let maxHeight = w.options.getMaxHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-max-height")); + + let prefHeight = w.options.getHeight?.() ?? styles.getPropertyValue("--comfy-widget-height"); + if (prefHeight.endsWith?.("%")) { + prefHeight = size[1] * (parseFloat(prefHeight.substring(0, prefHeight.length - 1)) / 100); + } else { + prefHeight = parseInt(prefHeight); + if (isNaN(minHeight)) { + minHeight = prefHeight; + } + } + if (isNaN(minHeight)) { + minHeight = 50; + } + if (!isNaN(maxHeight)) { + if (!isNaN(prefHeight)) { + prefHeight = Math.min(prefHeight, maxHeight); + } else { + prefHeight = maxHeight; + } + } + dom.push({ + minHeight, + prefHeight, + w, + }); + } else { + widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4; + } + } + + freeSpace -= widgetHeight; + + // Calculate sizes with all widgets at their min height + const prefGrow = []; // Nodes that want to grow to their prefd size + const canGrow = []; // Nodes that can grow to auto size + let growBy = 0; + for (const d of dom) { + freeSpace -= d.minHeight; + if (isNaN(d.prefHeight)) { + canGrow.push(d); + d.w.computedHeight = d.minHeight; + } else { + const diff = d.prefHeight - d.minHeight; + if (diff > 0) { + prefGrow.push(d); + growBy += diff; + d.diff = diff; + } else { + d.w.computedHeight = d.minHeight; + } + } + } + + if (this.imgs && !this.widgets.find((w) => w.name === ANIM_PREVIEW_WIDGET)) { + // Allocate space for image + freeSpace -= 220; + } + + if (freeSpace < 0) { + // Not enough space for all widgets so we need to grow + size[1] -= freeSpace; + this.graph.setDirtyCanvas(true); + } else { + // Share the space between each + const growDiff = freeSpace - growBy; + if (growDiff > 0) { + // All pref sizes can be fulfilled + freeSpace = growDiff; + for (const d of prefGrow) { + d.w.computedHeight = d.prefHeight; + } + } else { + // We need to grow evenly + const shared = -growDiff / prefGrow.length; + for (const d of prefGrow) { + d.w.computedHeight = d.prefHeight - shared; + } + freeSpace = 0; + } + + if (freeSpace > 0 && canGrow.length) { + // Grow any that are auto height + const shared = freeSpace / canGrow.length; + for (const d of canGrow) { + d.w.computedHeight += shared; + } + } + } + + // Position each of the widgets + for (const w of this.widgets) { + w.y = y; + if (w.computedHeight) { + y += w.computedHeight; + } else if (w.computeSize) { + y += w.computeSize()[1] + 4; + } else { + y += LiteGraph.NODE_WIDGET_HEIGHT + 4; + } + } +} + +// Override the compute visible nodes function to allow us to hide/show DOM elements when the node goes offscreen +const elementWidgets = new Set(); +const computeVisibleNodes = LGraphCanvas.prototype.computeVisibleNodes; +LGraphCanvas.prototype.computeVisibleNodes = function () { + const visibleNodes = computeVisibleNodes.apply(this, arguments); + for (const node of app.graph._nodes) { + if (elementWidgets.has(node)) { + const hidden = visibleNodes.indexOf(node) 
=== -1; + for (const w of node.widgets) { + if (w.element) { + w.element.hidden = hidden; + if (hidden) { + w.options.onHide?.(w); + } + } + } + } + } + + return visibleNodes; +}; + +let enableDomClipping = true; + +export function addDomClippingSetting() { + app.ui.settings.addSetting({ + id: "Comfy.DOMClippingEnabled", + name: "Enable DOM element clipping (enabling may reduce performance)", + type: "boolean", + defaultValue: enableDomClipping, + onChange(value) { + console.log("enableDomClipping", enableDomClipping); + enableDomClipping = !!value; + }, + }); +} + +LGraphNode.prototype.addDOMWidget = function (name, type, element, options) { + options = { hideOnZoom: true, selectOn: ["focus", "click"], ...options }; + + if (!element.parentElement) { + document.body.append(element); + } + + let mouseDownHandler; + if (element.blur) { + mouseDownHandler = (event) => { + if (!element.contains(event.target)) { + element.blur(); + } + }; + document.addEventListener("mousedown", mouseDownHandler); + } + + const widget = { + type, + name, + get value() { + return options.getValue?.() ?? undefined; + }, + set value(v) { + options.setValue?.(v); + widget.callback?.(widget.value); + }, + draw: function (ctx, node, widgetWidth, y, widgetHeight) { + if (widget.computedHeight == null) { + computeSize.call(node, node.size); + } + + const hidden = + node.flags?.collapsed || + (!!options.hideOnZoom && app.canvas.ds.scale < 0.5) || + widget.computedHeight <= 0 || + widget.type === "converted-widget"; + element.hidden = hidden; + element.style.display = hidden ? "none" : null; + if (hidden) { + widget.options.onHide?.(widget); + return; + } + + const margin = 10; + const elRect = ctx.canvas.getBoundingClientRect(); + const transform = new DOMMatrix() + .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height) + .multiplySelf(ctx.getTransform()) + .translateSelf(margin, margin + y); + + const scale = new DOMMatrix().scaleSelf(transform.a, transform.d); + + Object.assign(element.style, { + transformOrigin: "0 0", + transform: scale, + left: `${transform.a + transform.e}px`, + top: `${transform.d + transform.f}px`, + width: `${widgetWidth - margin * 2}px`, + height: `${(widget.computedHeight ?? 
50) - margin * 2}px`, + position: "absolute", + zIndex: app.graph._nodes.indexOf(node), + }); + + if (enableDomClipping) { + element.style.clipPath = getClipPath(node, element, elRect); + element.style.willChange = "clip-path"; + } + + this.options.onDraw?.(widget); + }, + element, + options, + onRemove() { + if (mouseDownHandler) { + document.removeEventListener("mousedown", mouseDownHandler); + } + element.remove(); + }, + }; + + for (const evt of options.selectOn) { + element.addEventListener(evt, () => { + app.canvas.selectNode(this); + app.canvas.bringToFront(this); + }); + } + + this.addCustomWidget(widget); + elementWidgets.add(this); + + const collapse = this.collapse; + this.collapse = function() { + collapse.apply(this, arguments); + if(this.flags?.collapsed) { + element.hidden = true; + element.style.display = "none"; + } + } + + const onRemoved = this.onRemoved; + this.onRemoved = function () { + element.remove(); + elementWidgets.delete(this); + onRemoved?.apply(this, arguments); + }; + + if (!this[SIZE]) { + this[SIZE] = true; + const onResize = this.onResize; + this.onResize = function (size) { + options.beforeResize?.call(widget, this); + computeSize.call(this, size); + onResize?.apply(this, arguments); + options.afterResize?.call(widget, this); + }; + } + + return widget; +}; diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index c5293dfa3..83a4ebc86 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -24,7 +24,7 @@ export function getPngMetadata(file) { const length = dataView.getUint32(offset); // Get the chunk type const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8)); - if (type === "tEXt") { + if (type === "tEXt" || type == "comf") { // Get the keyword let keyword_end = offset + 8; while (pngData[keyword_end] !== 0) { @@ -47,6 +47,105 @@ export function getPngMetadata(file) { }); } +function parseExifData(exifData) { + // Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian) + const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949; + + // Function to read 16-bit and 32-bit integers from binary data + function readInt(offset, isLittleEndian, length) { + let arr = exifData.slice(offset, offset + length) + if (length === 2) { + return new DataView(arr.buffer, arr.byteOffset, arr.byteLength).getUint16(0, isLittleEndian); + } else if (length === 4) { + return new DataView(arr.buffer, arr.byteOffset, arr.byteLength).getUint32(0, isLittleEndian); + } + } + + // Read the offset to the first IFD (Image File Directory) + const ifdOffset = readInt(4, isLittleEndian, 4); + + function parseIFD(offset) { + const numEntries = readInt(offset, isLittleEndian, 2); + const result = {}; + + for (let i = 0; i < numEntries; i++) { + const entryOffset = offset + 2 + i * 12; + const tag = readInt(entryOffset, isLittleEndian, 2); + const type = readInt(entryOffset + 2, isLittleEndian, 2); + const numValues = readInt(entryOffset + 4, isLittleEndian, 4); + const valueOffset = readInt(entryOffset + 8, isLittleEndian, 4); + + // Read the value(s) based on the data type + let value; + if (type === 2) { + // ASCII string + value = String.fromCharCode(...exifData.slice(valueOffset, valueOffset + numValues - 1)); + } + + result[tag] = value; + } + + return result; + } + + // Parse the first IFD + const ifdData = parseIFD(ifdOffset); + return ifdData; +} + +function splitValues(input) { + var output = {}; + for (var key in input) { + var value = input[key]; + var splitValues = value.split(':', 2); + 
output[splitValues[0]] = splitValues[1]; + } + return output; +} + +export function getWebpMetadata(file) { + return new Promise((r) => { + const reader = new FileReader(); + reader.onload = (event) => { + const webp = new Uint8Array(event.target.result); + const dataView = new DataView(webp.buffer); + + // Check that the WEBP signature is present + if (dataView.getUint32(0) !== 0x52494646 || dataView.getUint32(8) !== 0x57454250) { + console.error("Not a valid WEBP file"); + r(); + return; + } + + // Start searching for chunks after the WEBP signature + let offset = 12; + let txt_chunks = {}; + // Loop through the chunks in the WEBP file + while (offset < webp.length) { + const chunk_length = dataView.getUint32(offset + 4, true); + const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4)); + if (chunk_type === "EXIF") { + if (String.fromCharCode(...webp.slice(offset + 8, offset + 8 + 6)) == "Exif\0\0") { + offset += 6; + } + let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length)); + for (var key in data) { + var value = data[key]; + let index = value.indexOf(':'); + txt_chunks[value.slice(0, index)] = value.slice(index + 1); + } + } + + offset += 8 + chunk_length; + } + + r(txt_chunks); + }; + + reader.readAsArrayBuffer(file); + }); +} + export function getLatentMetadata(file) { return new Promise((r) => { const reader = new FileReader(); diff --git a/web/scripts/ui.js b/web/scripts/ui.js index c3b3fbda1..8a58d30b3 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -599,7 +599,7 @@ export class ComfyUI { const fileInput = $el("input", { id: "comfy-file-input", type: "file", - accept: ".json,image/png,.latent,.safetensors", + accept: ".json,image/png,.latent,.safetensors,image/webp", style: {display: "none"}, parent: document.body, onchange: () => { @@ -719,20 +719,22 @@ export class ComfyUI { filename += ".json"; } } - const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string - const blob = new Blob([json], {type: "application/json"}); - const url = URL.createObjectURL(blob); - const a = $el("a", { - href: url, - download: filename, - style: {display: "none"}, - parent: document.body, + app.graphToPrompt().then(p=>{ + const json = JSON.stringify(p.workflow, null, 2); // convert the data to a JSON string + const blob = new Blob([json], {type: "application/json"}); + const url = URL.createObjectURL(blob); + const a = $el("a", { + href: url, + download: filename, + style: {display: "none"}, + parent: document.body, + }); + a.click(); + setTimeout(function () { + a.remove(); + window.URL.revokeObjectURL(url); + }, 0); }); - a.click(); - setTimeout(function () { - a.remove(); - window.URL.revokeObjectURL(url); - }, 0); }, }), $el("button", { diff --git a/web/scripts/ui/imagePreview.js b/web/scripts/ui/imagePreview.js new file mode 100644 index 000000000..2a7f66b8f --- /dev/null +++ b/web/scripts/ui/imagePreview.js @@ -0,0 +1,97 @@ +import { $el } from "../ui.js"; + +export function calculateImageGrid(imgs, dw, dh) { + let best = 0; + let w = imgs[0].naturalWidth; + let h = imgs[0].naturalHeight; + const numImages = imgs.length; + + let cellWidth, cellHeight, cols, rows, shiftX; + // compact style + for (let c = 1; c <= numImages; c++) { + const r = Math.ceil(numImages / c); + const cW = dw / c; + const cH = dh / r; + const scaleX = cW / w; + const scaleY = cH / h; + + const scale = Math.min(scaleX, scaleY, 1); + const imageW = w * scale; + const imageH = h * scale; + const area = imageW * imageH * numImages; + + 
if (area > best) { + best = area; + cellWidth = imageW; + cellHeight = imageH; + cols = c; + rows = r; + shiftX = c * ((cW - imageW) / 2); + } + } + + return { cellWidth, cellHeight, cols, rows, shiftX }; +} + +export function createImageHost(node) { + const el = $el("div.comfy-img-preview"); + let currentImgs; + let first = true; + + function updateSize() { + let w = null; + let h = null; + + if (currentImgs) { + let elH = el.clientHeight; + if (first) { + first = false; + // On first run, if we are small then grow a bit + if (elH < 190) { + elH = 190; + } + el.style.setProperty("--comfy-widget-min-height", elH); + } else { + el.style.setProperty("--comfy-widget-min-height", null); + } + + const nw = node.size[0]; + ({ cellWidth: w, cellHeight: h } = calculateImageGrid(currentImgs, nw - 20, elH)); + w += "px"; + h += "px"; + + el.style.setProperty("--comfy-img-preview-width", w); + el.style.setProperty("--comfy-img-preview-height", h); + } + } + return { + el, + updateImages(imgs) { + if (imgs !== currentImgs) { + if (currentImgs == null) { + requestAnimationFrame(() => { + updateSize(); + }); + } + el.replaceChildren(...imgs); + currentImgs = imgs; + node.onResize(node.size); + node.graph.setDirtyCanvas(true, true); + } + }, + getHeight() { + updateSize(); + }, + onDraw() { + // Element from point uses a hittest find elements so we need to toggle pointer events + el.style.pointerEvents = "all"; + const over = document.elementFromPoint(app.canvas.mouse[0], app.canvas.mouse[1]); + el.style.pointerEvents = "none"; + + if(!over) return; + // Set the overIndex so Open Image etc work + const idx = currentImgs.indexOf(over); + node.overIndex = idx; + }, + }; +} diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index ba5a0a399..418084358 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -1,5 +1,6 @@ import { api } from "./api.js" import { getPngMetadata } from "./pnginfo.js"; +import "./domWidget.js"; function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { let defaultVal = inputData[1]["default"]; @@ -24,17 +25,58 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { } export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) { - const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, { + const widgets = addValueControlWidgets(node, targetWidget, defaultValue, values, { + addFilterList: false, + }); + return widgets[0]; +} + +export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", values, options) { + if (!options) options = {}; + + const widgets = []; + const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, { values: ["fixed", "increment", "decrement", "randomize"], serialize: false, // Don't include this in prompt. }); - valueControl.afterQueued = () => { + widgets.push(valueControl); + const isCombo = targetWidget.type === "combo"; + let comboFilter; + if (isCombo && options.addFilterList !== false) { + comboFilter = node.addWidget("string", "control_filter_list", "", function (v) {}, { + serialize: false, // Don't include this in prompt. 
+ }); + widgets.push(comboFilter); + } + + valueControl.afterQueued = () => { var v = valueControl.value; - if (targetWidget.type == "combo" && v !== "fixed") { - let current_index = targetWidget.options.values.indexOf(targetWidget.value); - let current_length = targetWidget.options.values.length; + if (isCombo && v !== "fixed") { + let values = targetWidget.options.values; + const filter = comboFilter?.value; + if (filter) { + let check; + if (filter.startsWith("/") && filter.endsWith("/")) { + try { + const regex = new RegExp(filter.substring(1, filter.length - 1)); + check = (item) => regex.test(item); + } catch (error) { + console.error("Error constructing RegExp filter for node " + node.id, filter, error); + } + } + if (!check) { + const lower = filter.toLocaleLowerCase(); + check = (item) => item.toLocaleLowerCase().includes(lower); + } + values = values.filter(item => check(item)); + if (!values.length && targetWidget.options.values.length) { + console.warn("Filter for node " + node.id + " has filtered out all items", filter); + } + } + let current_index = values.indexOf(targetWidget.value); + let current_length = values.length; switch (v) { case "increment": @@ -51,7 +93,7 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random current_index = Math.max(0, current_index); current_index = Math.min(current_length - 1, current_index); if (current_index >= 0) { - let value = targetWidget.options.values[current_index]; + let value = values[current_index]; targetWidget.value = value; targetWidget.callback(value); } @@ -88,7 +130,8 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random targetWidget.callback(targetWidget.value); } } - return valueControl; + + return widgets; }; function seedWidget(node, inputName, inputData, app) { @@ -98,166 +141,21 @@ function seedWidget(node, inputName, inputData, app) { seed.widget.linkedWidgets = [seedControl]; return seed; } - -const MultilineSymbol = Symbol(); -const MultilineResizeSymbol = Symbol(); - function addMultilineWidget(node, name, opts, app) { - const MIN_SIZE = 50; + const inputEl = document.createElement("textarea"); + inputEl.className = "comfy-multiline-input"; + inputEl.value = opts.defaultVal; + inputEl.placeholder = opts.placeholder || ""; - function computeSize(size) { - if (node.widgets[0].last_y == null) return; - - let y = node.widgets[0].last_y; - let freeSpace = size[1] - y; - - // Compute the height of all non customtext widgets - let widgetHeight = 0; - const multi = []; - for (let i = 0; i < node.widgets.length; i++) { - const w = node.widgets[i]; - if (w.type === "customtext") { - multi.push(w); - } else { - if (w.computeSize) { - widgetHeight += w.computeSize()[1] + 4; - } else { - widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4; - } - } - } - - // See how large each text input can be - freeSpace -= widgetHeight; - freeSpace /= multi.length + (!!node.imgs?.length); - - if (freeSpace < MIN_SIZE) { - // There isnt enough space for all the widgets, increase the size of the node - freeSpace = MIN_SIZE; - node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length)); - node.graph.setDirtyCanvas(true); - } - - // Position each of the widgets - for (const w of node.widgets) { - w.y = y; - if (w.type === "customtext") { - y += freeSpace; - w.computedHeight = freeSpace - multi.length*4; - } else if (w.computeSize) { - y += w.computeSize()[1] + 4; - } else { - y += LiteGraph.NODE_WIDGET_HEIGHT + 4; - } - } - - node.inputHeight = freeSpace; - } - - 
const widget = { - type: "customtext", - name, - get value() { - return this.inputEl.value; + const widget = node.addDOMWidget(name, "customtext", inputEl, { + getValue() { + return inputEl.value; }, - set value(x) { - this.inputEl.value = x; + setValue(v) { + inputEl.value = v; }, - draw: function (ctx, _, widgetWidth, y, widgetHeight) { - if (!this.parent.inputHeight) { - // If we are initially offscreen when created we wont have received a resize event - // Calculate it here instead - computeSize(node.size); - } - const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext"; - const margin = 10; - const elRect = ctx.canvas.getBoundingClientRect(); - const transform = new DOMMatrix() - .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height) - .multiplySelf(ctx.getTransform()) - .translateSelf(margin, margin + y); - - const scale = new DOMMatrix().scaleSelf(transform.a, transform.d) - Object.assign(this.inputEl.style, { - transformOrigin: "0 0", - transform: scale, - left: `${transform.a + transform.e}px`, - top: `${transform.d + transform.f}px`, - width: `${widgetWidth - (margin * 2)}px`, - height: `${this.parent.inputHeight - (margin * 2)}px`, - position: "absolute", - background: (!node.color)?'':node.color, - color: (!node.color)?'':'white', - zIndex: app.graph._nodes.indexOf(node), - }); - this.inputEl.hidden = !visible; - }, - }; - widget.inputEl = document.createElement("textarea"); - widget.inputEl.className = "comfy-multiline-input"; - widget.inputEl.value = opts.defaultVal; - widget.inputEl.placeholder = opts.placeholder || ""; - document.addEventListener("mousedown", function (event) { - if (!widget.inputEl.contains(event.target)) { - widget.inputEl.blur(); - } }); - widget.parent = node; - document.body.appendChild(widget.inputEl); - - node.addCustomWidget(widget); - - app.canvas.onDrawBackground = function () { - // Draw node isnt fired once the node is off the screen - // if it goes off screen quickly, the input may not be removed - // this shifts it off screen so it can be moved back if the node is visible. 
- for (let n in app.graph._nodes) { - n = graph._nodes[n]; - for (let w in n.widgets) { - let wid = n.widgets[w]; - if (Object.hasOwn(wid, "inputEl")) { - wid.inputEl.style.left = -8000 + "px"; - wid.inputEl.style.position = "absolute"; - } - } - } - }; - - node.onRemoved = function () { - // When removing this node we need to remove the input from the DOM - for (let y in this.widgets) { - if (this.widgets[y].inputEl) { - this.widgets[y].inputEl.remove(); - } - } - }; - - widget.onRemove = () => { - widget.inputEl?.remove(); - - // Restore original size handler if we are the last - if (!--node[MultilineSymbol]) { - node.onResize = node[MultilineResizeSymbol]; - delete node[MultilineSymbol]; - delete node[MultilineResizeSymbol]; - } - }; - - if (node[MultilineSymbol]) { - node[MultilineSymbol]++; - } else { - node[MultilineSymbol] = 1; - const onResize = (node[MultilineResizeSymbol] = node.onResize); - - node.onResize = function (size) { - computeSize(size); - - // Call original resizer handler - if (onResize) { - onResize.apply(this, arguments); - } - }; - } + widget.inputEl = inputEl; return { minWidth: 400, minHeight: 200, widget }; } @@ -306,14 +204,23 @@ export const ComfyWidgets = { }; }, BOOLEAN(node, inputName, inputData) { - let defaultVal = inputData[1]["default"]; + let defaultVal = false; + let options = {}; + if (inputData[1]) { + if (inputData[1].default) + defaultVal = inputData[1].default; + if (inputData[1].label_on) + options["on"] = inputData[1].label_on; + if (inputData[1].label_off) + options["off"] = inputData[1].label_off; + } return { widget: node.addWidget( "toggle", inputName, defaultVal, () => {}, - {"on": inputData[1].label_on, "off": inputData[1].label_off} + options, ) }; }, diff --git a/web/style.css b/web/style.css index 692fa31d6..378fe0a48 100644 --- a/web/style.css +++ b/web/style.css @@ -409,6 +409,21 @@ dialog::backdrop { width: calc(100% - 10px); } +.comfy-img-preview { + pointer-events: none; + overflow: hidden; + display: flex; + flex-wrap: wrap; + align-content: flex-start; + justify-content: center; +} + +.comfy-img-preview img { + object-fit: contain; + width: var(--comfy-img-preview-width); + height: var(--comfy-img-preview-height); +} + /* Search box */ .litegraph.litesearchbox {