diff --git a/.github/workflows/windows_release_cu118_dependencies_2.yml b/.github/workflows/windows_release_cu118_dependencies_2.yml index 42adee9e7..a7760b21e 100644 --- a/.github/workflows/windows_release_cu118_dependencies_2.yml +++ b/.github/workflows/windows_release_cu118_dependencies_2.yml @@ -2,6 +2,13 @@ name: "Windows Release cu118 dependencies 2" on: workflow_dispatch: + inputs: + xformers: + description: 'xformers version' + required: true + type: string + default: "xformers" + # push: # branches: # - master @@ -17,7 +24,7 @@ jobs: - shell: bash run: | - python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir + python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir python -m pip install --no-cache-dir ./temp_wheel_dir/* echo installed basic ls -lah temp_wheel_dir diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 000000000..7c7c3e19e --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @comfyanonymous diff --git a/README.md b/README.md index 5e32a74f3..baa8cf8b6 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git | Ctrl + O | Load workflow | | Ctrl + A | Select all nodes | | Ctrl + M | Mute/unmute selected nodes | +| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) | | Delete/Backspace | Delete selected nodes | | Ctrl + Delete/Backspace | Delete the current graph | | Space | Move the canvas around when held and moving the cursor | @@ -93,8 +94,8 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2``` -This is 
the command to install the nightly with ROCm 5.5 that supports the 7000 series and might have some performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.5 -r requirements.txt``` +This is the command to install the PyTorch nightly with ROCm 5.6, which supports the 7000 series and might have some performance improvements: +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6``` ### NVIDIA @@ -126,10 +127,10 @@ After this you should have everything installed and can proceed to running Comfy You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version. -1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide. +1. Install PyTorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest PyTorch nightly). 1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux. 1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies). -1. Launch ComfyUI by running `python main.py`. +1. Launch ComfyUI by running `python main.py --force-fp16`. Note that `--force-fp16` only works if you installed the latest PyTorch nightly. > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux). 
diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 2a16c8101..46fbf0a69 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -13,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import ( ) from ..ldm.modules.attention import SpatialTransformer -from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock +from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample from ..ldm.util import exists @@ -57,6 +57,7 @@ class ControlNet(nn.Module): transformer_depth_middle=None, ): super().__init__() + assert use_spatial_transformer == True, "use_spatial_transformer has to be true" if use_spatial_transformer: assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' @@ -200,13 +201,7 @@ class ControlNet(nn.Module): if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( + SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint @@ -259,13 +254,7 @@ class ControlNet(nn.Module): use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, 
use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint diff --git a/comfy/cli_args.py b/comfy/cli_args.py index bef1868b9..ec7d34a55 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -38,8 +38,14 @@ parser.add_argument("--port", type=int, default=8188, help="Set the listen port. parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") +parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") +parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") +cm_group = parser.add_mutually_exclusive_group() +cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") +cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.") + parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. 
Can boost speed but increase the chances of black images.") fp_group = parser.add_mutually_exclusive_group() @@ -80,7 +86,12 @@ parser.add_argument("--dont-print-server", action="store_true", help="Don't prin parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).") +parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") + args = parser.parse_args() if args.windows_standalone_build: args.auto_launch = True + +if args.disable_auto_launch: + args.auto_launch = False diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index e2bc3209d..8d04faf71 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -24,8 +24,8 @@ class ClipVisionModel(): return self.model.load_state_dict(sd, strict=False) def encode_image(self, image): - img = torch.clip((255. * image[0]), 0, 255).round().int() - inputs = self.processor(images=[img], return_tensors="pt") + img = torch.clip((255. 
* image), 0, 255).round().int() + inputs = self.processor(images=img, return_tensors="pt") outputs = self.model(**inputs) return outputs diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index 9688cbd52..a9eb9302f 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -148,6 +148,10 @@ vae_conversion_map_attn = [ ("q.", "query."), ("k.", "key."), ("v.", "value."), + ("q.", "to_q."), + ("k.", "to_k."), + ("v.", "to_v."), + ("proj_out.", "to_out.0."), ("proj_out.", "proj_attn."), ] diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 2ff10caf1..7eaf6ff62 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -180,7 +180,6 @@ class NoiseScheduleVP: def model_wrapper( model, - sampling_function, noise_schedule, model_type="noise", model_kwargs={}, @@ -295,7 +294,7 @@ def model_wrapper( if t_continuous.reshape((-1,)).shape[0] == 1: t_continuous = t_continuous.expand((x.shape[0])) t_input = get_model_input_time(t_continuous) - output = sampling_function(model, x, t_input, **model_kwargs) + output = model(x, t_input, **model_kwargs) if model_type == "noise": return output elif model_type == "x_start": @@ -843,10 +842,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex else: timesteps = sigmas.clone() - for s in range(timesteps.shape[0]): - timesteps[s] = (model.sigma_to_t(timesteps[s]) / 1000) + (1 / len(model.sigmas)) + alphas_cumprod = model.inner_model.alphas_cumprod - ns = NoiseScheduleVP('discrete', alphas_cumprod=model.inner_model.alphas_cumprod) + for s in range(timesteps.shape[0]): + timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod)) + + ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) if image is not None: img = image * ns.marginal_alpha(timesteps[0]) @@ -859,18 +860,15 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex img = noise if 
to_zero: - timesteps[-1] = (1 / len(model.sigmas)) + timesteps[-1] = (1 / len(alphas_cumprod)) device = noise.device - if model.parameterization == "v": - model_type = "v" - else: - model_type = "noise" + + model_type = "noise" model_fn = model_wrapper( - model.inner_model.inner_model.apply_model, - sampling_function, + model.predict_eps_discrete_timestep, ns, model_type=model_type, guidance_type="uncond", diff --git a/comfy/k_diffusion/external.py b/comfy/k_diffusion/external.py index 49ce5ae39..c1a137d9c 100644 --- a/comfy/k_diffusion/external.py +++ b/comfy/k_diffusion/external.py @@ -63,12 +63,17 @@ class DiscreteSchedule(nn.Module): t = torch.linspace(t_max, 0, n, device=self.sigmas.device) return sampling.append_zero(self.t_to_sigma(t)) - def sigma_to_t(self, sigma, quantize=None): - quantize = self.quantize if quantize is None else quantize + def sigma_to_discrete_timestep(self, sigma): log_sigma = sigma.log() dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) + + def sigma_to_t(self, sigma, quantize=None): + quantize = self.quantize if quantize is None else quantize if quantize: - return dists.abs().argmin(dim=0).view(sigma.shape) + return self.sigma_to_discrete_timestep(sigma) + log_sigma = sigma.log() + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx] @@ -85,6 +90,12 @@ class DiscreteSchedule(nn.Module): log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] return log_sigma.exp() + def predict_eps_discrete_timestep(self, input, t, **kwargs): + if t.dtype != torch.int64 and t.dtype != torch.int32: + t = t.round() + sigma = self.t_to_sigma(t) + input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5) + return (input - self(input, sigma, 
**kwargs)) / utils.append_dims(sigma, input.ndim) class DiscreteEpsDDPMDenoiser(DiscreteSchedule): """A wrapper for discrete schedule DDPM models that output eps (the predicted diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 4cc2534f3..beaa623f3 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -3,7 +3,6 @@ import math from scipy import integrate import torch from torch import nn -from torchdiffeq import odeint import torchsde from tqdm.auto import trange, tqdm @@ -131,9 +130,9 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. - eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: + eps = torch.randn_like(x) * s_noise x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) @@ -172,9 +171,9 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. - eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: + eps = torch.randn_like(x) * s_noise x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) @@ -201,9 +200,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
- eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: + eps = torch.randn_like(x) * s_noise x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) @@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o return x -@torch.no_grad() -def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): - extra_args = {} if extra_args is None else extra_args - s_in = x.new_ones([x.shape[0]]) - v = torch.randint_like(x, 2) * 2 - 1 - fevals = 0 - def ode_fn(sigma, x): - nonlocal fevals - with torch.enable_grad(): - x = x[0].detach().requires_grad_() - denoised = model(x, sigma * s_in, **extra_args) - d = to_d(x, sigma, denoised) - fevals += 1 - grad = torch.autograd.grad((d * v).sum(), x)[0] - d_ll = (v * grad).flatten(1).sum(1) - return d.detach(), d_ll - x_min = x, x.new_zeros([x.shape[0]]) - t = x.new_tensor([sigma_min, sigma_max]) - sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') - latent, delta_ll = sol[0][-1], sol[1][-1] - ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) - return ll_prior + delta_ll, {'fevals': fevals} - - class PIDStepSizeController: """A PID controller for ODE adaptive step size control.""" def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): @@ -656,23 +631,78 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl elif solver_type == 'midpoint': x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + if eta: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise old_denoised = denoised h_last = h return x +@torch.no_grad() +def 
sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + """DPM-Solver++(3M) SDE: multistep sampler that reuses up to two previous denoised outputs; extra_args=None is treated as an empty dict, and noise is only added when eta is non-zero.""" + + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + denoised_1, denoised_2 = None, None + h, h_1, h_2 = None, None, None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + h_eta = h * (eta + 1) + + x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised + + if h_2 is not None: + r0 = h_1 / h + r1 = h_2 / h + d1_0 = (denoised - denoised_1) / r0 + d1_1 = (denoised_1 - denoised_2) / r1 + d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1) + d2 = (d1_0 - d1_1) / (r0 + r1) + phi_2 = h_eta.neg().expm1() / h_eta + 1 + phi_3 = phi_2 / h_eta - 0.5 + x = x + phi_2 * d1 - phi_3 * d2 + elif h_1 is not None: + r = h_1 / h + d = (denoised - denoised_1) / r + phi_2 = h_eta.neg().expm1() / h_eta + 1 + x = x + phi_2 * d + + if eta: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise + + denoised_1, denoised_2 = denoised, denoised_1 + h_1, h_2 = h, h_1 + return x + +@torch.no_grad() +def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=(extra_args or {}).get("seed", None), cpu=False) if noise_sampler
is None else noise_sampler + return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) + @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) - @torch.no_grad() def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) - diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index 108fce1cf..139c8e01e 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -14,6 +14,7 @@ class DDIMSampler(object): self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule self.device = device + self.parameterization = kwargs.get("parameterization", "eps") def register_buffer(self, name, attr): if type(attr) == torch.Tensor: @@ -261,7 +262,7 @@ class DDIMSampler(object): b, *_, device = *x.shape, x.device if denoise_function is not None: - model_output = denoise_function(self.model.apply_model, x, t, **extra_args) + model_output = 
denoise_function(x, t, **extra_args) elif unconditional_conditioning is None or unconditional_guidance_scale == 1.: model_output = self.model.apply_model(x, t, c) else: @@ -289,13 +290,13 @@ class DDIMSampler(object): model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) - if self.model.parameterization == "v": + if self.parameterization == "v": e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x else: e_t = model_output if score_corrector is not None: - assert self.model.parameterization == "eps", 'not implemented' + assert self.parameterization == "eps", 'not implemented' e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas @@ -309,7 +310,7 @@ class DDIMSampler(object): sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) # current prediction for x_0 - if self.model.parameterization != "v": + if self.parameterization != "v": pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() else: pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 2284bcbdb..573cea6ac 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -36,7 +36,7 @@ def uniq(arr): def default(val, d): if exists(val): return val - return d() if isfunction(d) else d + return d def max_neg_value(t): @@ -52,9 +52,9 @@ def init_(tensor): # feedforward class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out, dtype=None): + def __init__(self, dim_in, dim_out, dtype=None, device=None): super().__init__() - self.proj = comfy.ops.Linear(dim_in, dim_out * 
2, dtype=dtype) + self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) @@ -62,19 +62,19 @@ class GEGLU(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential( - comfy.ops.Linear(dim, inner_dim, dtype=dtype), + comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device), nn.GELU() - ) if not glu else GEGLU(dim, inner_dim, dtype=dtype) + ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device) self.net = nn.Sequential( project_in, nn.Dropout(dropout), - comfy.ops.Linear(inner_dim, dim_out, dtype=dtype) + comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device) ) def forward(self, x): @@ -90,8 +90,8 @@ def zero_module(module): return module -def Normalize(in_channels, dtype=None): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype) +def Normalize(in_channels, dtype=None, device=None): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) class SpatialSelfAttention(nn.Module): @@ -148,7 +148,7 @@ class SpatialSelfAttention(nn.Module): class CrossAttentionBirchSan(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -156,12 +156,12 @@ class CrossAttentionBirchSan(nn.Module): self.scale = dim_head ** -0.5 self.heads = heads - self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype) - self.to_k = 
comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) - self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) + self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_out = nn.Sequential( - comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), + comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) ) @@ -245,7 +245,7 @@ class CrossAttentionBirchSan(nn.Module): class CrossAttentionDoggettx(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -253,12 +253,12 @@ class CrossAttentionDoggettx(nn.Module): self.scale = dim_head ** -0.5 self.heads = heads - self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype) - self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) - self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) + self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_out = nn.Sequential( - comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), + comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) ) @@ -343,7 +343,7 @@ class CrossAttentionDoggettx(nn.Module): return self.to_out(r2) class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, 
dim_head=64, dropout=0., dtype=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -351,12 +351,12 @@ class CrossAttention(nn.Module): self.scale = dim_head ** -0.5 self.heads = heads - self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype) - self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) - self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) + self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_out = nn.Sequential( - comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), + comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) ) @@ -399,7 +399,7 @@ class CrossAttention(nn.Module): class MemoryEfficientCrossAttention(nn.Module): # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None): super().__init__() print(f"Setting up {self.__class__.__name__}. 
Query dim is {query_dim}, context_dim is {context_dim} and using " f"{heads} heads.") @@ -409,11 +409,11 @@ class MemoryEfficientCrossAttention(nn.Module): self.heads = heads self.dim_head = dim_head - self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype) - self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) - self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) + self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout)) + self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None def forward(self, x, context=None, value=None, mask=None): @@ -450,7 +450,7 @@ class MemoryEfficientCrossAttention(nn.Module): return self.to_out(out) class CrossAttentionPytorch(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -458,11 +458,11 @@ class CrossAttentionPytorch(nn.Module): self.heads = heads self.dim_head = dim_head - self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype) - self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) - self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype) + self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, 
device=device) + self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout)) + self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None def forward(self, x, context=None, value=None, mask=None): @@ -508,17 +508,17 @@ else: class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False, dtype=None): + disable_self_attn=False, dtype=None, device=None): super().__init__() self.disable_self_attn = disable_self_attn self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None, dtype=dtype) # is a self-attention if not self.disable_self_attn - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype) + context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device) self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim, dtype=dtype) - self.norm2 = nn.LayerNorm(dim, dtype=dtype) - self.norm3 = nn.LayerNorm(dim, dtype=dtype) + heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device) + self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device) + self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device) self.checkpoint = checkpoint self.n_heads = n_heads self.d_head = d_head @@ -648,34 +648,34 @@ class 
SpatialTransformer(nn.Module): def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, - use_checkpoint=True, dtype=None): + use_checkpoint=True, dtype=None, device=None): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] * depth self.in_channels = in_channels inner_dim = n_heads * d_head - self.norm = Normalize(in_channels, dtype=dtype) + self.norm = Normalize(in_channels, dtype=dtype, device=device) if not use_linear: self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, - padding=0, dtype=dtype) + padding=0, dtype=dtype, device=device) else: - self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype) + self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device) self.transformer_blocks = nn.ModuleList( [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype) + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device) for d in range(depth)] ) if not use_linear: self.proj_out = nn.Conv2d(inner_dim,in_channels, kernel_size=1, stride=1, - padding=0, dtype=dtype) + padding=0, dtype=dtype, device=device) else: - self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype) + self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device) self.use_linear = use_linear def forward(self, x, context=None, transformer_options={}): diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 69ab21cdc..b596408d3 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -8,6 +8,7 @@ from typing import Optional, Any from ..attention import MemoryEfficientCrossAttention from comfy import model_management +import 
comfy.ops if model_management.xformers_enabled_vae(): import xformers @@ -48,7 +49,7 @@ class Upsample(nn.Module): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, + self.conv = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, @@ -67,7 +68,7 @@ class Downsample(nn.Module): self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, + self.conv = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, @@ -95,30 +96,30 @@ class ResnetBlock(nn.Module): self.swish = torch.nn.SiLU(inplace=True) self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, + self.conv1 = comfy.ops.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, + self.temb_proj = comfy.ops.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout, inplace=True) - self.conv2 = torch.nn.Conv2d(out_channels, + self.conv2 = comfy.ops.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, + self.conv_shortcut = comfy.ops.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, + self.nin_shortcut = comfy.ops.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, @@ -188,22 +189,22 @@ class AttnBlock(nn.Module): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, + self.q = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, + self.k = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = 
torch.nn.Conv2d(in_channels, + self.v = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, + self.proj_out = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, @@ -243,22 +244,22 @@ class MemoryEfficientAttnBlock(nn.Module): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, + self.q = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, + self.k = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = torch.nn.Conv2d(in_channels, + self.v = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, + self.proj_out = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, @@ -302,22 +303,22 @@ class MemoryEfficientAttnBlockPytorch(nn.Module): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, + self.q = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, + self.k = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = torch.nn.Conv2d(in_channels, + self.v = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, + self.proj_out = comfy.ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, @@ -399,14 +400,14 @@ class Model(nn.Module): # timestep embedding self.temb = nn.Module() self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, + comfy.ops.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, + comfy.ops.Linear(self.temb_ch, self.temb_ch), ]) # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, + self.conv_in = comfy.ops.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, @@ -475,7 
+476,7 @@ class Model(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, + self.conv_out = comfy.ops.Conv2d(block_in, out_ch, kernel_size=3, stride=1, @@ -548,7 +549,7 @@ class Encoder(nn.Module): self.in_channels = in_channels # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, + self.conv_in = comfy.ops.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, @@ -593,7 +594,7 @@ class Encoder(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, + self.conv_out = comfy.ops.Conv2d(block_in, 2*z_channels if double_z else z_channels, kernel_size=3, stride=1, @@ -653,7 +654,7 @@ class Decoder(nn.Module): self.z_shape, np.prod(self.z_shape))) # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, + self.conv_in = comfy.ops.Conv2d(z_channels, block_in, kernel_size=3, stride=1, @@ -695,7 +696,7 @@ class Decoder(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, + self.conv_out = comfy.ops.Conv2d(block_in, out_ch, kernel_size=3, stride=1, diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 92f2438ef..90c153465 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -19,45 +19,6 @@ from ..attention import SpatialTransformer from comfy.ldm.util import exists -# dummy replace -def convert_module_to_f16(x): - pass - -def convert_module_to_f32(x): - pass - - -## go -class AttentionPool2d(nn.Module): - """ - Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py - """ - - def __init__( - self, - spacial_dim: int, - embed_dim: int, - num_heads_channels: int, - output_dim: int = None, - ): - super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) - self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) - 
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) - self.num_heads = embed_dim // num_heads_channels - self.attention = QKVAttention(self.num_heads) - - def forward(self, x): - b, c, *_spatial = x.shape - x = x.reshape(b, c, -1) # NC(HW) - x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) - x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) - x = self.qkv_proj(x) - x = self.attention(x) - x = self.c_proj(x) - return x[:, :, 0] - - class TimestepBlock(nn.Module): """ Any module where forward() takes timestep embeddings as a second argument. @@ -111,14 +72,14 @@ class Upsample(nn.Module): upsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype) + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device) def forward(self, x, output_shape=None): assert x.shape[1] == self.channels @@ -138,19 +99,6 @@ class Upsample(nn.Module): x = self.conv(x) return x -class TransposedUpsample(nn.Module): - 'Learned 2x upsampling without padding' - def __init__(self, channels, out_channels=None, ks=5): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - - self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) - - def forward(self,x): - return self.up(x) - - class Downsample(nn.Module): """ A downsampling layer with an optional convolution. @@ -160,7 +108,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. 
""" - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -169,7 +117,7 @@ class Downsample(nn.Module): stride = 2 if dims != 3 else (1, 2, 2) if use_conv: self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device ) else: assert self.channels == self.out_channels @@ -208,7 +156,8 @@ class ResBlock(TimestepBlock): use_checkpoint=False, up=False, down=False, - dtype=None + dtype=None, + device=None, ): super().__init__() self.channels = channels @@ -220,19 +169,19 @@ class ResBlock(TimestepBlock): self.use_scale_shift_norm = use_scale_shift_norm self.in_layers = nn.Sequential( - nn.GroupNorm(32, channels, dtype=dtype), + nn.GroupNorm(32, channels, dtype=dtype, device=device), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype), + conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device), ) self.updown = up or down if up: - self.h_upd = Upsample(channels, False, dims, dtype=dtype) - self.x_upd = Upsample(channels, False, dims, dtype=dtype) + self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device) + self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device) elif down: - self.h_upd = Downsample(channels, False, dims, dtype=dtype) - self.x_upd = Downsample(channels, False, dims, dtype=dtype) + self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device) + self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device) else: self.h_upd = self.x_upd = nn.Identity() @@ -240,15 +189,15 @@ class ResBlock(TimestepBlock): nn.SiLU(), linear( emb_channels, - 2 * self.out_channels if 
use_scale_shift_norm else self.out_channels, dtype=dtype + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device ), ) self.out_layers = nn.Sequential( - nn.GroupNorm(32, self.out_channels, dtype=dtype), + nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype) + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device) ), ) @@ -256,10 +205,10 @@ class ResBlock(TimestepBlock): self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1, dtype=dtype + dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device ) else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype) + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device) def forward(self, x, emb): """ @@ -295,142 +244,6 @@ class ResBlock(TimestepBlock): h = self.out_layers(h) return self.skip_connection(x) + h - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. - Originally ported from here, but adapted to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
- """ - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) - if use_new_attention_order: - # split qkv before split heads - self.attention = QKVAttention(self.num_heads) - else: - # split heads before split qkv - self.attention = QKVAttentionLegacy(self.num_heads) - - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - - def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - #return pt_checkpoint(self._forward, x) # pytorch - - def _forward(self, x): - b, c, *spatial = x.shape - x = x.reshape(b, c, -1) - qkv = self.qkv(self.norm(x)) - h = self.attention(qkv) - h = self.proj_out(h) - return (x + h).reshape(b, c, *spatial) - - -def count_flops_attn(model, _x, y): - """ - A counter for the `thop` package to count the operations in an - attention operation. - Meant to be used like: - macs, params = thop.profile( - model, - inputs=(inputs, timestamps), - custom_ops={QKVAttention: QKVAttention.count_flops}, - ) - """ - b, c, *spatial = y[0].shape - num_spatial = int(np.prod(spatial)) - # We perform two matmuls with the same number of ops. - # The first computes the weight matrix, the second computes - # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial ** 2) * c - model.total_ops += th.DoubleTensor([matmul_ops]) - - -class QKVAttentionLegacy(nn.Module): - """ - A module which performs QKV attention. 
Matches legacy QKVAttention + input/ouput heads shaping - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class QKVAttention(nn.Module): - """ - A module which performs QKV attention and splits in a different order. - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. 
- """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.chunk(3, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", - (q * scale).view(bs * self.n_heads, ch, length), - (k * scale).view(bs * self.n_heads, ch, length), - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - class Timestep(nn.Module): def __init__(self, dim): super().__init__() @@ -503,8 +316,10 @@ class UNetModel(nn.Module): use_linear_in_transformer=False, adm_in_channels=None, transformer_depth_middle=None, + device=None, ): super().__init__() + assert use_spatial_transformer == True, "use_spatial_transformer has to be true" if use_spatial_transformer: assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' 
@@ -564,9 +379,9 @@ class UNetModel(nn.Module): time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim, dtype=self.dtype), + linear(model_channels, time_embed_dim, dtype=self.dtype, device=device), nn.SiLU(), - linear(time_embed_dim, time_embed_dim, dtype=self.dtype), + linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device), ) if self.num_classes is not None: @@ -579,9 +394,9 @@ class UNetModel(nn.Module): assert adm_in_channels is not None self.label_emb = nn.Sequential( nn.Sequential( - linear(adm_in_channels, time_embed_dim, dtype=self.dtype), + linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device), nn.SiLU(), - linear(time_embed_dim, time_embed_dim, dtype=self.dtype), + linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device), ) ) else: @@ -590,7 +405,7 @@ class UNetModel(nn.Module): self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype) + conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device) ) ] ) @@ -609,7 +424,8 @@ class UNetModel(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, - dtype=self.dtype + dtype=self.dtype, + device=device, ) ] ch = mult * model_channels @@ -628,17 +444,10 @@ class UNetModel(nn.Module): disabled_sa = False if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( + layers.append(SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype + 
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -657,11 +466,12 @@ class UNetModel(nn.Module): use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, - dtype=self.dtype + dtype=self.dtype, + device=device, ) if resblock_updown else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype + ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device ) ) ) @@ -686,18 +496,13 @@ class UNetModel(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, - dtype=self.dtype + dtype=self.dtype, + device=device, ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device ), ResBlock( ch, @@ -706,7 +511,8 @@ class UNetModel(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, - dtype=self.dtype + dtype=self.dtype, + device=device, ), ) self._feature_size += ch @@ -724,7 +530,8 @@ class UNetModel(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, - dtype=self.dtype + dtype=self.dtype, + device=device, ) ] ch = model_channels * mult @@ -744,16 +551,10 @@ class UNetModel(nn.Module): if not exists(num_attention_blocks) or i < num_attention_blocks[level]: layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads_upsample, - num_head_channels=dim_head, - 
use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( + SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device ) ) if level and i == self.num_res_blocks[level]: @@ -768,43 +569,28 @@ class UNetModel(nn.Module): use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, up=True, - dtype=self.dtype + dtype=self.dtype, + device=device, ) if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype) + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch self.out = nn.Sequential( - nn.GroupNorm(32, ch, dtype=self.dtype), + nn.GroupNorm(32, ch, dtype=self.dtype, device=device), nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)), ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - nn.GroupNorm(32, ch, dtype=self.dtype), - conv_nd(dims, model_channels, n_embed, 1), + nn.GroupNorm(32, ch, dtype=self.dtype, device=device), + conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device), #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - self.output_blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. 
- """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): """ Apply the model to an input batch. diff --git a/comfy/model_base.py b/comfy/model_base.py index 9197dc4b9..ad661ec7d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -4,27 +4,27 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep import numpy as np +from enum import Enum from . import utils +class ModelType(Enum): + EPS = 1 + V_PREDICTION = 2 + class BaseModel(torch.nn.Module): - def __init__(self, model_config, v_prediction=False): + def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__() unet_config = model_config.unet_config self.latent_format = model_config.latent_format self.model_config = model_config self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) - self.diffusion_model = UNetModel(**unet_config) - self.v_prediction = v_prediction - if self.v_prediction: - self.parameterization = "v" - else: - self.parameterization = "eps" - + self.diffusion_model = UNetModel(**unet_config, device=device) + self.model_type = model_type self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 - print("v_prediction", v_prediction) + print("model_type", model_type.name) print("adm", self.adm_channels) def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, @@ -99,53 +99,58 @@ class BaseModel(torch.nn.Module): if self.get_dtype() == torch.float16: clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16) 
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16) + + if self.model_type == ModelType.V_PREDICTION: + unet_state_dict["v_pred"] = torch.tensor([]) + return {**unet_state_dict, **vae_state_dict, **clip_state_dict} +def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): + adm_inputs = [] + weights = [] + noise_aug = [] + for unclip_cond in unclip_conditioning: + for adm_cond in unclip_cond["clip_vision_output"].image_embeds: + weight = unclip_cond["strength"] + noise_augment = unclip_cond["noise_augmentation"] + noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) + c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight + weights.append(weight) + noise_aug.append(noise_augment) + adm_inputs.append(adm_out) + + if len(noise_aug) > 1: + adm_out = torch.stack(adm_inputs).sum(0) + noise_augment = noise_augment_merge + noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) + c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) + + return adm_out class SD21UNCLIP(BaseModel): - def __init__(self, model_config, noise_aug_config, v_prediction=True): - super().__init__(model_config, v_prediction) + def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None): + super().__init__(model_config, model_type, device=device) self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config) def encode_adm(self, **kwargs): unclip_conditioning = kwargs.get("unclip_conditioning", None) device = kwargs["device"] - - if unclip_conditioning is not None: - adm_inputs = [] - weights = [] - noise_aug = [] - for unclip_cond in unclip_conditioning: - adm_cond = 
unclip_cond["clip_vision_output"].image_embeds - weight = unclip_cond["strength"] - noise_augment = unclip_cond["noise_augmentation"] - noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) - adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight - weights.append(weight) - noise_aug.append(noise_augment) - adm_inputs.append(adm_out) - - if len(noise_aug) > 1: - adm_out = torch.stack(adm_inputs).sum(0) - #TODO: add a way to control this - noise_augment = 0.05 - noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = self.noise_augmentor(adm_out[:, :self.noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) - adm_out = torch.cat((c_adm, noise_level_emb), 1) + if unclip_conditioning is None: + return torch.zeros((1, self.adm_channels)) else: - adm_out = torch.zeros((1, self.adm_channels)) + return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05)) - return adm_out class SDInpaint(BaseModel): - def __init__(self, model_config, v_prediction=False): - super().__init__(model_config, v_prediction) + def __init__(self, model_config, model_type=ModelType.EPS, device=None): + super().__init__(model_config, model_type, device=device) self.concat_keys = ("mask", "masked_image") class SDXLRefiner(BaseModel): - def __init__(self, model_config, v_prediction=False): - super().__init__(model_config, v_prediction) + def __init__(self, model_config, model_type=ModelType.EPS, device=None): + super().__init__(model_config, model_type, device=device) self.embedder = Timestep(256) def encode_adm(self, **kwargs): @@ -160,7 +165,6 @@ class SDXLRefiner(BaseModel): else: aesthetic_score = kwargs.get("aesthetic_score", 6) - print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score) out 
= [] out.append(self.embedder(torch.Tensor([height]))) out.append(self.embedder(torch.Tensor([width]))) @@ -171,8 +175,8 @@ class SDXLRefiner(BaseModel): return torch.cat((clip_pooled.to(flat.device), flat), dim=1) class SDXL(BaseModel): - def __init__(self, model_config, v_prediction=False): - super().__init__(model_config, v_prediction) + def __init__(self, model_config, model_type=ModelType.EPS, device=None): + super().__init__(model_config, model_type, device=device) self.embedder = Timestep(256) def encode_adm(self, **kwargs): @@ -184,7 +188,6 @@ class SDXL(BaseModel): target_width = kwargs.get("target_width", width) target_height = kwargs.get("target_height", height) - print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height) out = [] out.append(self.embedder(torch.Tensor([height]))) out.append(self.embedder(torch.Tensor([width]))) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index cf764e0b7..49ee9ea70 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -113,8 +113,63 @@ def model_config_from_unet_config(unet_config): if model_config.matches(unet_config): return model_config(unet_config) + print("no match", unet_config) return None def model_config_from_unet(state_dict, unet_key_prefix, use_fp16): unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16) return model_config_from_unet_config(unet_config) + + +def model_config_from_diffusers_unet(state_dict, use_fp16): + match = {} + match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1] + match["model_channels"] = state_dict["conv_in.weight"].shape[0] + match["in_channels"] = state_dict["conv_in.weight"].shape[1] + match["adm_in_channels"] = None + if "class_embedding.linear_1.weight" in state_dict: + match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1] + elif "add_embedding.linear_1.weight" in state_dict: + match["adm_in_channels"] = 
state_dict["add_embedding.linear_1.weight"].shape[1] + + SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], + 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048} + + SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384, + 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4], + 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280} + + SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, + 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], + 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} + + SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], + 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} + + SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], + 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} + + SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, + 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], + 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768} + + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl] + + for unet_config in supported_models: + matches = True + for k in match: + if match[k] != unet_config[k]: + matches = False + break + if matches: + return model_config_from_unet_config(unet_config) + return None diff --git a/comfy/model_management.py b/comfy/model_management.py index e148408b8..4dd15b41c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -49,6 +49,7 @@ except: try: if torch.backends.mps.is_available(): cpu_state = CPUState.MPS + import torch.mps except: pass @@ -204,7 +205,11 @@ print(f"Set vram state to: {vram_state.name}") def get_torch_device_name(device): if hasattr(device, 'type'): if device.type == "cuda": - return "{} {}".format(device, torch.cuda.get_device_name(device)) + try: + allocator_backend = torch.cuda.get_allocator_backend() + except: + allocator_backend = "" + return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend) else: return "{}".format(device.type) else: @@ -233,10 +238,9 @@ def unload_model(): accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model) model_accelerated = False - + current_loaded_model.unpatch_model() 
current_loaded_model.model.to(current_loaded_model.offload_device) current_loaded_model.model_patches_to(current_loaded_model.offload_device) - current_loaded_model.unpatch_model() current_loaded_model = None if vram_state != VRAMState.HIGH_VRAM: soft_empty_cache() @@ -258,15 +262,11 @@ def load_model_gpu(model): if model is current_loaded_model: return unload_model() - try: - real_model = model.patch_model() - except Exception as e: - model.unpatch_model() - raise e torch_dev = model.load_device model.model_patches_to(torch_dev) model.model_patches_to(model.model_dtype()) + current_loaded_model = model if is_device_cpu(torch_dev): vram_set_state = VRAMState.DISABLED @@ -280,21 +280,33 @@ def load_model_gpu(model): if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary vram_set_state = VRAMState.LOW_VRAM - current_loaded_model = model - + real_model = model.model + patch_model_to = None if vram_set_state == VRAMState.DISABLED: pass elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: model_accelerated = False - real_model.to(torch_dev) - else: - if vram_set_state == VRAMState.NO_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) - elif vram_set_state == VRAMState.LOW_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) + patch_model_to = torch_dev + try: + real_model = model.patch_model(device_to=patch_model_to) + except Exception as e: + model.unpatch_model() + unload_model() + raise e + + if patch_model_to is not None: + real_model.to(torch_dev) + + if vram_set_state == VRAMState.NO_VRAM: + device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) 
model_accelerated = True + elif vram_set_state == VRAMState.LOW_VRAM: + device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) + accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) + model_accelerated = True + return current_loaded_model def load_controlnet_gpu(control_models): @@ -352,6 +364,7 @@ def text_encoder_device(): if args.gpu_only: return get_torch_device() elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: + #NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough. return get_torch_device() else: @@ -522,7 +535,7 @@ def should_use_fp16(device=None, model_params=0): return False #FP16 is just broken on these cards - nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600"] + nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"] for x in nvidia_16_series: if x in props.name: return False diff --git a/comfy/samplers.py b/comfy/samplers.py index 81d1facd8..28cd46667 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,6 +6,7 @@ from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps import math +from comfy import model_base def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) return abs(a*b) // math.gcd(a, b) @@ -16,6 +17,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 + if 'timestep_start' in cond[1]: + timestep_start = cond[1]['timestep_start'] + if timestep_in[0] > timestep_start: + return None + if 'timestep_end' in cond[1]: + timestep_end = 
cond[1]['timestep_end'] + if timestep_in[0] < timestep_end: + return None if 'area' in cond[1]: area = cond[1]['area'] if 'strength' in cond[1]: @@ -180,12 +189,13 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con continue to_run += [(p, COND)] - for x in uncond: - p = get_area_and_mult(x, x_in, cond_concat_in, timestep) - if p is None: - continue + if uncond is not None: + for x in uncond: + p = get_area_and_mult(x, x_in, cond_concat_in, timestep) + if p is None: + continue - to_run += [(p, UNCOND)] + to_run += [(p, UNCOND)] while len(to_run) > 0: first = to_run[0] @@ -247,7 +257,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c['transformer_options'] = transformer_options - output = model_function(input_x, timestep_, **c).chunk(batch_chunks) + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + else: + output = model_function(input_x, timestep_, **c).chunk(batch_chunks) del input_x model_management.throw_exception_if_processing_interrupted() @@ -270,6 +283,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con max_total_area = model_management.maximum_batch_area() + if math.isclose(cond_scale, 1.0): + uncond = None + cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) if "sampler_cfg_function" in model_options: args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep} @@ -331,6 +347,17 @@ def ddim_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) +def sgm_scheduler(model, steps): + sigs = [] + timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int) + for x in range(len(timesteps)): + ts = timesteps[x] + if ts > 999: + ts 
= 999 + sigs.append(model.t_to_sigma(torch.tensor(ts))) + sigs += [0.0] + return torch.FloatTensor(sigs) + def blank_inpaint_image_like(latent_image): blank_image = torch.ones_like(latent_image) # these are the values for "zero" in pixel space translated to latent space @@ -424,6 +451,35 @@ def create_cond_with_same_area_if_none(conds, c): n = c[1].copy() conds += [[smallest[0], n]] +def calculate_start_end_timesteps(model, conds): + for t in range(len(conds)): + x = conds[t] + + timestep_start = None + timestep_end = None + if 'start_percent' in x[1]: + timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0))) + if 'end_percent' in x[1]: + timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0))) + + if (timestep_start is not None) or (timestep_end is not None): + n = x[1].copy() + if (timestep_start is not None): + n['timestep_start'] = timestep_start + if (timestep_end is not None): + n['timestep_end'] = timestep_end + conds[t] = [x[0], n] + +def pre_run_control(model, conds): + for t in range(len(conds)): + x = conds[t] + + timestep_start = None + timestep_end = None + percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0)) + if 'control' in x[1]: + x[1]['control'].pre_run(model.inner_model, percent_to_timestep_function) + def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] cond_other = [] @@ -480,19 +536,19 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type): class KSampler: - SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"] + SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] 
+ "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model self.model_denoise = CFGNoisePredictor(self.model) - if self.model.parameterization == "v": + if self.model.model_type == model_base.ModelType.V_PREDICTION: self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True) else: self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True) - self.model_wrap.parameterization = self.model.parameterization + self.model_k = KSamplerX0Inpaint(self.model_wrap) self.device = device if scheduler not in self.SCHEDULERS: @@ -525,6 +581,8 @@ class KSampler: sigmas = simple_scheduler(self.model_wrap, steps) elif self.scheduler == "ddim_uniform": sigmas = ddim_scheduler(self.model_wrap, steps) + elif self.scheduler == "sgm_uniform": + sigmas = sgm_scheduler(self.model_wrap, steps) else: print("error invalid scheduler", self.scheduler) @@ -567,13 +625,18 @@ class KSampler: resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device) resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device) + calculate_start_end_timesteps(self.model_wrap, negative) + calculate_start_end_timesteps(self.model_wrap, positive) + #make sure each cond area has an opposite one with the same area for c in positive: create_cond_with_same_area_if_none(negative, c) for c in negative: create_cond_with_same_area_if_none(positive, c) - apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x]) + pre_run_control(self.model_wrap, negative + positive) + + apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) if 
self.model.is_adm(): @@ -614,7 +677,7 @@ class KSampler: elif self.sampler == "ddim": timesteps = [] for s in range(sigmas.shape[0]): - timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s])) + timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s])) noise_mask = None if denoise_mask is not None: noise_mask = 1.0 - denoise_mask @@ -638,7 +701,7 @@ class KSampler: x_T=z_enc, x0=latent_image, img_callback=ddim_callback, - denoise_function=sampling_function, + denoise_function=self.model_wrap.predict_eps_discrete_timestep, extra_args=extra_args, mask=noise_mask, to_zero=sigmas[-1]==0, diff --git a/comfy/sd.py b/comfy/sd.py index 125b15b77..bff9ee141 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -70,13 +70,27 @@ def load_lora(lora, to_load): alpha = lora[alpha_name].item() loaded_keys.add(alpha_name) - A_name = "{}.lora_up.weight".format(x) - B_name = "{}.lora_down.weight".format(x) - mid_name = "{}.lora_mid.weight".format(x) + regular_lora = "{}.lora_up.weight".format(x) + diffusers_lora = "{}_lora.up.weight".format(x) + transformers_lora = "{}.lora_linear_layer.up.weight".format(x) + A_name = None - if A_name in lora.keys(): + if regular_lora in lora.keys(): + A_name = regular_lora + B_name = "{}.lora_down.weight".format(x) + mid_name = "{}.lora_mid.weight".format(x) + elif diffusers_lora in lora.keys(): + A_name = diffusers_lora + B_name = "{}_lora.down.weight".format(x) + mid_name = None + elif transformers_lora in lora.keys(): + A_name = transformers_lora + B_name ="{}.lora_linear_layer.down.weight".format(x) + mid_name = None + + if A_name is not None: mid = None - if mid_name in lora.keys(): + if mid_name is not None and mid_name in lora.keys(): mid = lora[mid_name] loaded_keys.add(mid_name) patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) @@ -170,20 +184,31 @@ def model_lora_keys_clip(model, key_map={}): if k in sdk: lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) key_map[lora_key] = k + lora_key = 
"lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base key_map[lora_key] = k clip_l_present = True + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: if clip_l_present: lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base + key_map[lora_key] = k + lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k else: lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner - key_map[lora_key] = k + key_map[lora_key] = k + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + key_map[lora_key] = k return key_map @@ -198,10 +223,26 @@ def model_lora_keys_unet(model, key_map={}): diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config) for k in diffusers_keys: if k.endswith(".weight"): + unet_key = "diffusion_model.{}".format(diffusers_keys[k]) key_lora = k[:-len(".weight")].replace(".", "_") - key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k]) + key_map["lora_unet_{}".format(key_lora)] = unet_key + + diffusers_lora_prefix = ["", "unet."] + for p in diffusers_lora_prefix: + diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_")) + if diffusers_lora_key.endswith(".to_out.0"): + diffusers_lora_key = diffusers_lora_key[:-2] + key_map[diffusers_lora_key] = unet_key return key_map +def set_attr(obj, attr, value): 
+ attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + prev = getattr(obj, attrs[-1]) + setattr(obj, attrs[-1], torch.nn.Parameter(value)) + del prev + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0): self.size = size @@ -330,7 +371,7 @@ class ModelPatcher: sd.pop(k) return sd - def patch_model(self): + def patch_model(self, device_to=None): model_sd = self.model_state_dict() for key in self.patches: if key not in model_sd: @@ -340,10 +381,14 @@ class ModelPatcher: weight = model_sd[key] if key not in self.backup: - self.backup[key] = weight.clone() + self.backup[key] = weight.to(self.offload_device) - temp_weight = weight.to(torch.float32, copy=True) - weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) + if device_to is not None: + temp_weight = weight.float().to(device_to, copy=True) + else: + temp_weight = weight.to(torch.float32, copy=True) + out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) + set_attr(self.model, key, out_weight) del temp_weight return self.model @@ -367,15 +412,19 @@ class ModelPatcher: else: weight += alpha * w1.type(weight.dtype).to(weight.device) elif len(v) == 4: #lora/locon - mat1 = v[0] - mat2 = v[1] + mat1 = v[0].float().to(weight.device) + mat2 = v[1].float().to(weight.device) if v[2] is not None: alpha *= v[2] / mat2.shape[0] if v[3] is not None: #locon mid weights, hopefully the math is fine because I didn't properly test it - final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] - mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) - weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) + mat3 = v[3].float().to(weight.device) + final_shape = [mat2.shape[1], mat2.shape[0], 
mat3.shape[2], mat3.shape[3]] + mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) + try: + weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) elif len(v) == 8: #lokr w1 = v[0] w2 = v[1] @@ -389,20 +438,27 @@ class ModelPatcher: if w1 is None: dim = w1_b.shape[0] w1 = torch.mm(w1_a.float(), w1_b.float()) + else: + w1 = w1.float().to(weight.device) if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(w2_a.float(), w2_b.float()) + w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device)) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float()) + w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device)) + else: + w2 = w2.float().to(weight.device) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) if v[2] is not None and dim is not None: alpha *= v[2] / dim - weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device) + try: + weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) else: #loha w1a = v[0] w1b = v[1] @@ -413,21 +469,24 @@ class ModelPatcher: if v[5] is not None: #cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float()) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float()) + m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device)) + m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device)) 
else: - m1 = torch.mm(w1a.float(), w1b.float()) - m2 = torch.mm(w2a.float(), w2b.float()) + m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device)) + m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device)) + + try: + weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) + except Exception as e: + print("ERROR", key, e) - weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device) return weight def unpatch_model(self): - model_sd = self.model_state_dict() keys = list(self.backup.keys()) + for k in keys: - model_sd[k][:] = self.backup[k] - del self.backup[k] + set_attr(self.model, k, self.backup[k]) self.backup = {} @@ -647,16 +706,57 @@ def broadcast_image_to(tensor, target_batch_size, batched_number): else: return torch.cat([tensor] * batched_number, dim=0) -class ControlNet: - def __init__(self, control_model, global_average_pooling=False, device=None): - self.control_model = control_model +class ControlBase: + def __init__(self, device=None): self.cond_hint_original = None self.cond_hint = None self.strength = 1.0 + self.timestep_percent_range = (1.0, 0.0) + self.timestep_range = None + if device is None: device = model_management.get_torch_device() self.device = device self.previous_controlnet = None + + def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)): + self.cond_hint_original = cond_hint + self.strength = strength + self.timestep_percent_range = timestep_percent_range + return self + + def pre_run(self, model, percent_to_timestep_function): + self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1])) + if self.previous_controlnet is not None: + self.previous_controlnet.pre_run(model, percent_to_timestep_function) + + def set_previous_controlnet(self, controlnet): + self.previous_controlnet = controlnet + return self + + def cleanup(self): + if 
self.previous_controlnet is not None: + self.previous_controlnet.cleanup() + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + self.timestep_range = None + + def get_models(self): + out = [] + if self.previous_controlnet is not None: + out += self.previous_controlnet.get_models() + return out + + def copy_to(self, c): + c.cond_hint_original = self.cond_hint_original + c.strength = self.strength + c.timestep_percent_range = self.timestep_percent_range + +class ControlNet(ControlBase): + def __init__(self, control_model, global_average_pooling=False, device=None): + super().__init__(device) + self.control_model = control_model self.global_average_pooling = global_average_pooling def get_control(self, x_noisy, t, cond, batched_number): @@ -664,6 +764,13 @@ class ControlNet: if self.previous_controlnet is not None: control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + if self.timestep_range is not None: + if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: + if control_prev is not None: + return control_prev + else: + return {} + output_dtype = x_noisy.dtype if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: @@ -711,37 +818,64 @@ class ControlNet: out['input'] = control_prev['input'] return out - def set_cond_hint(self, cond_hint, strength=1.0): - self.cond_hint_original = cond_hint - self.strength = strength - return self - - def set_previous_controlnet(self, controlnet): - self.previous_controlnet = controlnet - return self - - def cleanup(self): - if self.previous_controlnet is not None: - self.previous_controlnet.cleanup() - if self.cond_hint is not None: - del self.cond_hint - self.cond_hint = None - def copy(self): c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) - c.cond_hint_original = self.cond_hint_original - c.strength = self.strength + 
self.copy_to(c) return c def get_models(self): - out = [] - if self.previous_controlnet is not None: - out += self.previous_controlnet.get_models() + out = super().get_models() out.append(self.control_model) return out + def load_controlnet(ckpt_path, model=None): controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True) + + controlnet_config = None + if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format + use_fp16 = model_management.should_use_fp16() + controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config + diffusers_keys = utils.unet_to_diffusers(controlnet_config) + diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" + diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" + + count = 0 + loop = True + while loop: + suffix = [".weight", ".bias"] + for s in suffix: + k_in = "controlnet_down_blocks.{}{}".format(count, s) + k_out = "zero_convs.{}.0{}".format(count, s) + if k_in not in controlnet_data: + loop = False + break + diffusers_keys[k_in] = k_out + count += 1 + + count = 0 + loop = True + while loop: + suffix = [".weight", ".bias"] + for s in suffix: + if count == 0: + k_in = "controlnet_cond_embedding.conv_in{}".format(s) + else: + k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s) + k_out = "input_hint_block.{}{}".format(count * 2, s) + if k_in not in controlnet_data: + k_in = "controlnet_cond_embedding.conv_out{}".format(s) + loop = False + diffusers_keys[k_in] = k_out + count += 1 + + new_sd = {} + for k in diffusers_keys: + if k in controlnet_data: + new_sd[diffusers_keys[k]] = controlnet_data.pop(k) + + controlnet_data = new_sd + pth_key = 'control_model.zero_convs.0.0.weight' pth = False key = 'zero_convs.0.0.weight' @@ -757,11 +891,11 @@ def load_controlnet(ckpt_path, model=None): print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path) return net - use_fp16 = 
model_management.should_use_fp16() - - controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config + if controlnet_config is None: + use_fp16 = model_management.should_use_fp16() + controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config controlnet_config.pop("out_channels") - controlnet_config["hint_channels"] = 3 + controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = cldm.ControlNet(**controlnet_config) if pth: @@ -799,24 +933,25 @@ def load_controlnet(ckpt_path, model=None): control = ControlNet(control_model, global_average_pooling=global_average_pooling) return control -class T2IAdapter: +class T2IAdapter(ControlBase): def __init__(self, t2i_model, channels_in, device=None): + super().__init__(device) self.t2i_model = t2i_model self.channels_in = channels_in - self.strength = 1.0 - if device is None: - device = model_management.get_torch_device() - self.device = device - self.previous_controlnet = None self.control_input = None - self.cond_hint_original = None - self.cond_hint = None def get_control(self, x_noisy, t, cond, batched_number): control_prev = None if self.previous_controlnet is not None: control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) + if self.timestep_range is not None: + if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: + if control_prev is not None: + return control_prev + else: + return {} + if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint @@ -861,33 +996,11 @@ class T2IAdapter: out['output'] = control_prev['output'] return out - def set_cond_hint(self, cond_hint, strength=1.0): - self.cond_hint_original = cond_hint - self.strength = strength - return self - - def set_previous_controlnet(self, 
controlnet): - self.previous_controlnet = controlnet - return self - def copy(self): c = T2IAdapter(self.t2i_model, self.channels_in) - c.cond_hint_original = self.cond_hint_original - c.strength = self.strength + self.copy_to(c) return c - def cleanup(self): - if self.previous_controlnet is not None: - self.previous_controlnet.cleanup() - if self.cond_hint is not None: - del self.cond_hint - self.cond_hint = None - - def get_models(self): - out = [] - if self.previous_controlnet is not None: - out += self.previous_controlnet.get_models() - return out def load_t2i_adapter(t2i_data): keys = t2i_data.keys() @@ -997,11 +1110,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl if "noise_aug_config" in model_config_params: noise_aug_config = model_config_params["noise_aug_config"] - v_prediction = False + model_type = model_base.ModelType.EPS if "parameterization" in model_config_params: if model_config_params["parameterization"] == "v": - v_prediction = True + model_type = model_base.ModelType.V_PREDICTION clip = None vae = None @@ -1021,11 +1134,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor) if config['model']["target"].endswith("LatentInpaintDiffusion"): - model = model_base.SDInpaint(model_config, v_prediction=v_prediction) + model = model_base.SDInpaint(model_config, model_type=model_type) elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): - model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction) + model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type) else: - model = model_base.BaseModel(model_config, v_prediction=v_prediction) + model = model_base.BaseModel(model_config, model_type=model_type) if fp16: model = model.half() @@ -1087,8 +1200,7 @@ def load_checkpoint_guess_config(ckpt_path, 
output_vae=True, output_clip=True, o clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) offload_device = model_management.unet_offload_device() - model = model_config.get_model(sd, "model.diffusion_model.") - model = model.to(offload_device) + model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device) model.load_model_weights(sd, "model.diffusion_model.") if output_vae: @@ -1117,66 +1229,24 @@ def load_unet(unet_path): #load unet in diffusers format parameters = calculate_parameters(sd, "") fp16 = model_management.should_use_fp16(model_params=parameters) - match = {} - match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1] - match["model_channels"] = sd["conv_in.weight"].shape[0] - match["in_channels"] = sd["conv_in.weight"].shape[1] - match["adm_in_channels"] = None - if "class_embedding.linear_1.weight" in sd: - match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1] + model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) + if model_config is None: + print("ERROR UNSUPPORTED UNET", unet_path) + return None - SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], - 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048} + diffusers_keys = utils.unet_to_diffusers(model_config.unet_config) - SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384, - 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 
'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280} - - SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, - 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} - - SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} - - SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} - - SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, - 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768} - - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, 
SD21_unclipl] - print("match", match) - for unet_config in supported_models: - matches = True - for k in match: - if match[k] != unet_config[k]: - matches = False - break - if matches: - diffusers_keys = utils.unet_to_diffusers(unet_config) - new_sd = {} - for k in diffusers_keys: - if k in sd: - new_sd[diffusers_keys[k]] = sd.pop(k) - else: - print(diffusers_keys[k], k) - offload_device = model_management.unet_offload_device() - model_config = model_detection.model_config_from_unet_config(unet_config) - model = model_config.get_model(new_sd, "") - model = model.to(offload_device) - model.load_model_weights(new_sd, "") - return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) + new_sd = {} + for k in diffusers_keys: + if k in sd: + new_sd[diffusers_keys[k]] = sd.pop(k) + else: + print(diffusers_keys[k], k) + offload_device = model_management.unet_offload_device() + model = model_config.get_model(new_sd, "") + model = model.to(offload_device) + model.load_model_weights(new_sd, "") + return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) def save_checkpoint(output_path, model, clip, vae, metadata=None): try: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index d504bf77d..feca41880 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -91,13 +91,15 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_up_textual_embeddings(self, tokens, current_embeds): out_tokens = [] - next_new_token = token_dict_size = current_embeds.weight.shape[0] + next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1 embedding_weights = [] for x in tokens: tokens_temp = [] for y in x: if isinstance(y, int): + if y == token_dict_size: #EOS token + y = -1 tokens_temp += [y] else: if y.shape[0] == current_embeds.weight.shape[1]: @@ -110,15 +112,21 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens_temp += [self.empty_tokens[0][-1]] 
out_tokens += [tokens_temp] + n = token_dict_size if len(embedding_weights) > 0: - new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype) - new_embedding.weight[:token_dict_size] = current_embeds.weight[:] - n = token_dict_size + new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype) + new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1] for x in embedding_weights: new_embedding.weight[n] = x n += 1 + new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding self.transformer.set_input_embeddings(new_embedding) - return out_tokens + + processed_tokens = [] + for x in out_tokens: + processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one + + return processed_tokens def forward(self, tokens): backup_embeds = self.transformer.get_input_embeddings() diff --git a/comfy/supported_models.py b/comfy/supported_models.py index b7fdfe9fe..95fc8f3f5 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -53,13 +53,13 @@ class SD20(supported_models_base.BASE): latent_format = latent_formats.SD15 - def v_prediction(self, state_dict, prefix=""): + def model_type(self, state_dict, prefix=""): if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix) out = state_dict[k] if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out. 
- return True - return False + return model_base.ModelType.V_PREDICTION + return model_base.ModelType.EPS def process_clip_state_dict(self, state_dict): state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) @@ -109,8 +109,8 @@ class SDXLRefiner(supported_models_base.BASE): latent_format = latent_formats.SDXL - def get_model(self, state_dict, prefix=""): - return model_base.SDXLRefiner(self) + def get_model(self, state_dict, prefix="", device=None): + return model_base.SDXLRefiner(self, device=device) def process_clip_state_dict(self, state_dict): keys_to_replace = {} @@ -126,7 +126,8 @@ class SDXLRefiner(supported_models_base.BASE): def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") - state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") + if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g: + state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") replace_prefix["clip_g"] = "conditioner.embedders.0.model" state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix) return state_dict_g @@ -145,8 +146,14 @@ class SDXL(supported_models_base.BASE): latent_format = latent_formats.SDXL - def get_model(self, state_dict, prefix=""): - return model_base.SDXL(self) + def model_type(self, state_dict, prefix=""): + if "v_pred" in state_dict: + return model_base.ModelType.V_PREDICTION + else: + return model_base.ModelType.EPS + + def get_model(self, state_dict, prefix="", device=None): + return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device) def process_clip_state_dict(self, state_dict): keys_to_replace = {} @@ -165,7 +172,8 @@ class SDXL(supported_models_base.BASE): replace_prefix = {} keys_to_replace = {} state_dict_g = 
diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") - state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") + if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g: + state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") for k in state_dict: if k.startswith("clip_l"): state_dict_g[k] = state_dict[k] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 86dc67068..d0088bbd5 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -41,8 +41,8 @@ class BASE: return False return True - def v_prediction(self, state_dict, prefix=""): - return False + def model_type(self, state_dict, prefix=""): + return model_base.ModelType.EPS def inpaint_model(self): return self.unet_config["in_channels"] > 4 @@ -53,13 +53,13 @@ class BASE: for x in self.unet_extra_config: self.unet_config[x] = self.unet_extra_config[x] - def get_model(self, state_dict, prefix=""): + def get_model(self, state_dict, prefix="", device=None): if self.inpaint_model(): - return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix)) + return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device) elif self.noise_aug_config is not None: - return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix)) + return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device) else: - return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix)) + return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device) def process_clip_state_dict(self, state_dict): return state_dict diff --git a/comfy/utils.py b/comfy/utils.py index 956ac1773..3bbe4f9a9 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -4,18 +4,20 @@ import struct import comfy.checkpoint_pickle import 
safetensors.torch -def load_torch_file(ckpt, safe_load=False): +def load_torch_file(ckpt, safe_load=False, device=None): + if device is None: + device = torch.device("cpu") if ckpt.lower().endswith(".safetensors"): - sd = safetensors.torch.load_file(ckpt, device="cpu") + sd = safetensors.torch.load_file(ckpt, device=device.type) else: if safe_load: if not 'weights_only' in torch.load.__code__.co_varnames: print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") safe_load = False if safe_load: - pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) + pl_sd = torch.load(ckpt, map_location=device, weights_only=True) else: - pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle) + pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: @@ -118,20 +120,24 @@ UNET_MAP_RESNET = { } UNET_MAP_BASIC = { - "label_emb.0.0.weight": "class_embedding.linear_1.weight", - "label_emb.0.0.bias": "class_embedding.linear_1.bias", - "label_emb.0.2.weight": "class_embedding.linear_2.weight", - "label_emb.0.2.bias": "class_embedding.linear_2.bias", - "input_blocks.0.0.weight": "conv_in.weight", - "input_blocks.0.0.bias": "conv_in.bias", - "out.0.weight": "conv_norm_out.weight", - "out.0.bias": "conv_norm_out.bias", - "out.2.weight": "conv_out.weight", - "out.2.bias": "conv_out.bias", - "time_embed.0.weight": "time_embedding.linear_1.weight", - "time_embed.0.bias": "time_embedding.linear_1.bias", - "time_embed.2.weight": "time_embedding.linear_2.weight", - "time_embed.2.bias": "time_embedding.linear_2.bias" + ("label_emb.0.0.weight", "class_embedding.linear_1.weight"), + ("label_emb.0.0.bias", "class_embedding.linear_1.bias"), + ("label_emb.0.2.weight", "class_embedding.linear_2.weight"), + ("label_emb.0.2.bias", "class_embedding.linear_2.bias"), + ("label_emb.0.0.weight", 
"add_embedding.linear_1.weight"), + ("label_emb.0.0.bias", "add_embedding.linear_1.bias"), + ("label_emb.0.2.weight", "add_embedding.linear_2.weight"), + ("label_emb.0.2.bias", "add_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias") } def unet_to_diffusers(unet_config): @@ -206,7 +212,7 @@ def unet_to_diffusers(unet_config): n += 1 for k in UNET_MAP_BASIC: - diffusers_unet_map[UNET_MAP_BASIC[k]] = k + diffusers_unet_map[k[1]] = k[0] return diffusers_unet_map diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 15377af14..b80c8b9a2 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -2,6 +2,35 @@ import torch from nodes import MAX_RESOLUTION +def composite(destination, source, x, y, mask = None, multiplier = 8): + x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) + y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) + + left, top = (x // multiplier, y // multiplier) + right, bottom = (left + source.shape[3], top + source.shape[2],) + + + if mask is None: + mask = torch.ones_like(source) + else: + mask = mask.clone() + mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear") + mask = mask.repeat((source.shape[0], source.shape[1], 1, 1)) + + # calculate the bounds of the source that will be overlapping the destination + # this prevents the source trying to overwrite latent pixels that are out of bounds + # of the destination + visible_width, 
visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) + + mask = mask[:, :, :visible_height, :visible_width] + inverse_mask = torch.ones_like(mask) - mask + + source_portion = mask * source[:, :, :visible_height, :visible_width] + destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] + + destination[:, :, top:bottom, left:right] = source_portion + destination_portion + return destination + class LatentCompositeMasked: @classmethod def INPUT_TYPES(s): @@ -25,36 +54,31 @@ class LatentCompositeMasked: output = destination.copy() destination = destination["samples"].clone() source = source["samples"] + output["samples"] = composite(destination, source, x, y, mask, 8) + return (output,) - x = max(-source.shape[3] * 8, min(x, destination.shape[3] * 8)) - y = max(-source.shape[2] * 8, min(y, destination.shape[2] * 8)) +class ImageCompositeMasked: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "destination": ("IMAGE",), + "source": ("IMAGE",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + }, + "optional": { + "mask": ("MASK",), + } + } + RETURN_TYPES = ("IMAGE",) + FUNCTION = "composite" - left, top = (x // 8, y // 8) - right, bottom = (left + source.shape[3], top + source.shape[2],) - - - if mask is None: - mask = torch.ones_like(source) - else: - mask = mask.clone() - mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear") - mask = mask.repeat((source.shape[0], source.shape[1], 1, 1)) - - # calculate the bounds of the source that will be overlapping the destination - # this prevents the source trying to overwrite latent pixels that are out of bounds - # of the destination - visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) - - mask = mask[:, :, 
:visible_height, :visible_width] - inverse_mask = torch.ones_like(mask) - mask - - source_portion = mask * source[:, :, :visible_height, :visible_width] - destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] - - destination[:, :, top:bottom, left:right] = source_portion + destination_portion - - output["samples"] = destination + CATEGORY = "image" + def composite(self, destination, source, x, y, mask = None): + destination = destination.clone().movedim(-1, 1) + output = composite(destination, source.movedim(-1, 1), x, y, mask, 1).movedim(1, -1) return (output,) class MaskToImage: @@ -253,6 +277,7 @@ class FeatherMask: NODE_CLASS_MAPPINGS = { "LatentCompositeMasked": LatentCompositeMasked, + "ImageCompositeMasked": ImageCompositeMasked, "MaskToImage": MaskToImage, "ImageToMask": ImageToMask, "SolidMask": SolidMask, diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index 95c4cfece..bce4b3dd0 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -1,9 +1,13 @@ import comfy.sd import comfy.utils +import comfy.model_base + import folder_paths import json import os +from comfy.cli_args import args + class ModelMergeSimple: @classmethod def INPUT_TYPES(s): @@ -99,10 +103,36 @@ class CheckpointSave: if prompt is not None: prompt_info = json.dumps(prompt) - metadata = {"prompt": prompt_info} - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) + metadata = {} + + enable_modelspec = True + if isinstance(model.model, comfy.model_base.SDXL): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" + elif isinstance(model.model, comfy.model_base.SDXLRefiner): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" + else: + enable_modelspec = False + + if enable_modelspec: + metadata["modelspec.sai_model_spec"] = "1.0.0" + metadata["modelspec.implementation"] = "sgm" + metadata["modelspec.title"] = 
"{} {}".format(filename, counter) + + #TODO: + # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512", + # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", + # "v2-inpainting" + + if model.model.model_type == comfy.model_base.ModelType.EPS: + metadata["modelspec.predict_key"] = "epsilon" + elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: + metadata["modelspec.predict_key"] = "v" + + if not args.disable_metadata: + metadata["prompt"] = prompt_info + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 3be141dfe..a138b292e 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -59,8 +59,8 @@ class Blend: def g(self, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) -def gaussian_kernel(kernel_size: int, sigma: float): - x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") +def gaussian_kernel(kernel_size: int, sigma: float, device=None): + x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij") d = torch.sqrt(x * x + y * y) g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) return g / g.sum() @@ -101,7 +101,7 @@ class Blur: batch_size, height, width, channels = image.shape kernel_size = blur_radius * 2 + 1 - kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) + kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1) image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) padded_image = F.pad(image, 
(blur_radius,blur_radius,blur_radius,blur_radius), 'reflect') diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index f9252ea0b..abd182e6e 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -37,12 +37,23 @@ class ImageUpscaleWithModel: device = model_management.get_torch_device() upscale_model.to(device) in_img = image.movedim(-1,-3).to(device) + free_memory = model_management.get_free_memory(device) + + tile = 512 + overlap = 32 + + oom = True + while oom: + try: + steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) + pbar = comfy.utils.ProgressBar(steps) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) + oom = False + except model_management.OOM_EXCEPTION as e: + tile //= 2 + if tile < 128: + raise e - tile = 128 + 64 - overlap = 8 - steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) - pbar = comfy.utils.ProgressBar(steps) - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) upscale_model.cpu() s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) return (s,) diff --git a/cuda_malloc.py b/cuda_malloc.py new file mode 100644 index 000000000..144cdacd3 --- /dev/null +++ b/cuda_malloc.py @@ -0,0 +1,84 @@ +import os +import importlib.util +from comfy.cli_args import args + +#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. 
+def get_gpu_names(): + if os.name == 'nt': + import ctypes + + # Define necessary C structures and types + class DISPLAY_DEVICEA(ctypes.Structure): + _fields_ = [ + ('cb', ctypes.c_ulong), + ('DeviceName', ctypes.c_char * 32), + ('DeviceString', ctypes.c_char * 128), + ('StateFlags', ctypes.c_ulong), + ('DeviceID', ctypes.c_char * 128), + ('DeviceKey', ctypes.c_char * 128) + ] + + # Load user32.dll + user32 = ctypes.windll.user32 + + # Call EnumDisplayDevicesA + def enum_display_devices(): + device_info = DISPLAY_DEVICEA() + device_info.cb = ctypes.sizeof(device_info) + device_index = 0 + gpu_names = set() + + while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0): + device_index += 1 + gpu_names.add(device_info.DeviceString.decode('utf-8')) + return gpu_names + return enum_display_devices() + else: + return set() + +blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M", + "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620", + "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000", + "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000", + "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M", + "GeForce GTX 1650", "GeForce GTX 1630" + } + +def cuda_malloc_supported(): + try: + names = get_gpu_names() + except: + names = set() + for x in names: + if "NVIDIA" in x: + for b in blacklist: + if b in x: + return False + return True + + +if not args.cuda_malloc: + try: + version = "" + torch_spec = importlib.util.find_spec("torch") + for folder in torch_spec.submodule_search_locations: + ver_file = os.path.join(folder, "version.py") + if os.path.isfile(ver_file): + spec = importlib.util.spec_from_file_location("torch_version_import", 
ver_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + version = module.__version__ + if int(version[0]) >= 2: #enable by default for torch version 2.0 and up + args.cuda_malloc = cuda_malloc_supported() + except: + pass + + +if args.cuda_malloc and not args.disable_cuda_malloc: + env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) + if env_var is None: + env_var = "backend:cudaMallocAsync" + else: + env_var += ",backend:cudaMallocAsync" + + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index 175202aeb..e37808b03 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -51,9 +51,10 @@ class Example: "default": 0, "min": 0, #Minimum value "max": 4096, #Maximum value - "step": 64 #Slider's step + "step": 64, #Slider's step + "display": "number" # Cosmetic only: display as "number" or "slider" }), - "float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "display": "number"}), "print_to_screen": (["enable", "disable"],), "string_field": ("STRING", { "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node diff --git a/execution.py b/execution.py index a40b1dd36..a1a7c75c8 100644 --- a/execution.py +++ b/execution.py @@ -6,7 +6,6 @@ import threading import heapq import traceback import gc -import time import torch import nodes @@ -43,11 +42,14 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): # check if node wants the lists - intput_is_list = False + input_is_list = False if hasattr(obj, "INPUT_IS_LIST"): - intput_is_list = obj.INPUT_IS_LIST + input_is_list = obj.INPUT_IS_LIST - max_len_input = max([len(x) for x in input_data_all.values()]) + 
if len(input_data_all) == 0: + max_len_input = 0 + else: + max_len_input = max([len(x) for x in input_data_all.values()]) # get a slice of inputs, repeat last input when list isn't long enough def slice_dict(d, i): @@ -57,11 +59,15 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): return d_new results = [] - if intput_is_list: + if input_is_list: if allow_interrupt: nodes.before_node_execution() results.append(getattr(obj, func)(**input_data_all)) - else: + elif max_len_input == 0: + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)()) + else: for i in range(max_len_input): if allow_interrupt: nodes.before_node_execution() diff --git a/folder_paths.py b/folder_paths.py index eb7d39b88..e321690dd 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -43,6 +43,10 @@ def set_output_directory(output_dir): global output_directory output_directory = output_dir +def set_temp_directory(temp_dir): + global temp_directory + temp_directory = temp_dir + def get_output_directory(): global output_directory return output_directory @@ -111,6 +115,8 @@ def add_model_folder_path(folder_name, full_folder_path): global folder_names_and_paths if folder_name in folder_names_and_paths: folder_names_and_paths[folder_name][0].append(full_folder_path) + else: + folder_names_and_paths[folder_name] = ([full_folder_path], set()) def get_folder_paths(folder_name): return folder_names_and_paths[folder_name][0][:] diff --git a/latent_preview.py b/latent_preview.py index 833e6822e..30c1d1317 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -1,6 +1,5 @@ import torch -from PIL import Image, ImageOps -from io import BytesIO +from PIL import Image import struct import numpy as np from comfy.cli_args import args, LatentPreviewMethod @@ -15,26 +14,7 @@ class LatentPreviewer: def decode_latent_to_preview_image(self, preview_format, x0): preview_image = self.decode_latent_to_preview(x0) - - if hasattr(Image, 'Resampling'): - 
resampling = Image.Resampling.BILINEAR - else: - resampling = Image.ANTIALIAS - - preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), resampling) - - preview_type = 1 - if preview_format == "JPEG": - preview_type = 1 - elif preview_format == "PNG": - preview_type = 2 - - bytesIO = BytesIO() - header = struct.pack(">I", preview_type) - bytesIO.write(header) - preview_image.save(bytesIO, format=preview_format, quality=95) - preview_bytes = bytesIO.getvalue() - return preview_bytes + return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION) class TAESDPreviewerImpl(LatentPreviewer): def __init__(self, taesd): diff --git a/main.py b/main.py index 802e4bfe4..a4038db4b 100644 --- a/main.py +++ b/main.py @@ -51,7 +51,6 @@ import threading import gc from comfy.cli_args import args -import comfy.utils if os.name == "nt": import logging @@ -62,7 +61,9 @@ if __name__ == "__main__": os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) print("Set cuda device to:", args.cuda_device) + import cuda_malloc +import comfy.utils import yaml import execution @@ -71,6 +72,17 @@ from server import BinaryEventTypes from nodes import init_custom_nodes import comfy.model_management +def cuda_malloc_warning(): + device = comfy.model_management.get_torch_device() + device_name = comfy.model_management.get_torch_device_name(device) + cuda_malloc_warning = False + if "cudaMallocAsync" in device_name: + for b in cuda_malloc.blacklist: + if b in device_name: + cuda_malloc_warning = True + if cuda_malloc_warning: + print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") + def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: @@ -91,15 +103,15 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): def hijack_progress(server): - def hook(value, total, preview_image_bytes): + def hook(value, total, 
preview_image): server.send_sync("progress", {"value": value, "max": total}, server.client_id) - if preview_image_bytes is not None: - server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id) + if preview_image is not None: + server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) comfy.utils.set_progress_bar_global_hook(hook) def cleanup_temp(): - temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") + temp_dir = folder_paths.get_temp_directory() if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) @@ -126,6 +138,10 @@ def load_extra_path_config(yaml_path): if __name__ == "__main__": + if args.temp_directory: + temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") + print(f"Setting temp directory to: {temp_dir}") + folder_paths.set_temp_directory(temp_dir) cleanup_temp() loop = asyncio.new_event_loop() @@ -142,6 +158,9 @@ if __name__ == "__main__": load_extra_path_config(config_path) init_custom_nodes() + + cuda_malloc_warning() + server.add_routes() hijack_progress(server) @@ -159,6 +178,8 @@ if __name__ == "__main__": if args.auto_launch: def startup_server(address, port): import webbrowser + if os.name == 'nt' and address == '0.0.0.0': + address = '127.0.0.1' webbrowser.open(f"http://{address}:{port}") call_on_start = startup_server diff --git a/nodes.py b/nodes.py index 32bcf141d..5b144c2fc 100644 --- a/nodes.py +++ b/nodes.py @@ -26,6 +26,8 @@ import comfy.utils import comfy.clip_vision import comfy.model_management +from comfy.cli_args import args + import importlib import folder_paths @@ -204,6 +206,28 @@ class ConditioningZeroOut: c.append(n) return (c, ) +class ConditioningSetTimestepRange: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 
0.001}) + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "set_range" + + CATEGORY = "advanced/conditioning" + + def set_range(self, conditioning, start, end): + c = [] + for t in conditioning: + d = t[1].copy() + d['start_percent'] = 1.0 - start + d['end_percent'] = 1.0 - end + n = [t[0], d] + c.append(n) + return (c, ) + class VAEDecode: @classmethod def INPUT_TYPES(s): @@ -330,12 +354,22 @@ class SaveLatent: if prompt is not None: prompt_info = json.dumps(prompt) - metadata = {"prompt": prompt_info} - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) + metadata = None + if not args.disable_metadata: + metadata = {"prompt": prompt_info} + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) file = f"{filename}_{counter:05}_.latent" + + results = list() + results.append({ + "filename": file, + "subfolder": subfolder, + "type": "output" + }) + file = os.path.join(full_output_folder, file) output = {} @@ -343,7 +377,7 @@ class SaveLatent: output["latent_format_version_0"] = torch.tensor([]) comfy.utils.save_torch_file(output, file, metadata=metadata) - return {} + return { "ui": { "latents": results } } class LoadLatent: @@ -497,7 +531,9 @@ class LoraLoader: if self.loaded_lora[0] == lora_path: lora = self.loaded_lora[1] else: - del self.loaded_lora + temp = self.loaded_lora + self.loaded_lora = None + del temp if lora is None: lora = comfy.utils.load_torch_file(lora_path, safe_load=True) @@ -578,9 +614,58 @@ class ControlNetApply: if 'control' in t[1]: c_net.set_previous_controlnet(t[1]['control']) n[1]['control'] = c_net + n[1]['control_apply_to_uncond'] = True c.append(n) return (c, ) + +class ControlNetApplyAdvanced: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "control_net": ("CONTROL_NET", ), + "image": ("IMAGE", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 
10.0, "step": 0.01}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) + }} + + RETURN_TYPES = ("CONDITIONING","CONDITIONING") + RETURN_NAMES = ("positive", "negative") + FUNCTION = "apply_controlnet" + + CATEGORY = "conditioning" + + def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent): + if strength == 0: + return (positive, negative) + + control_hint = image.movedim(-1,1) + cnets = {} + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + + prev_cnet = d.get('control', None) + if prev_cnet in cnets: + c_net = cnets[prev_cnet] + else: + c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent)) + c_net.set_previous_controlnet(prev_cnet) + cnets[prev_cnet] = c_net + + d['control'] = c_net + d['control_apply_to_uncond'] = False + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1]) + + class UNETLoader: @classmethod def INPUT_TYPES(s): @@ -686,7 +771,7 @@ class StyleModelApply: CATEGORY = "conditioning/style_model" def apply_stylemodel(self, clip_vision_output, style_model, conditioning): - cond = style_model.get_cond(clip_vision_output) + cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0) c = [] for t in conditioning: n = [torch.cat((t[0], cond), dim=1), t[1].copy()] @@ -970,6 +1055,47 @@ class LatentComposite: samples_out["samples"] = s return (samples_out,) +class LatentBlend: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "samples1": ("LATENT",), + "samples2": ("LATENT",), + "blend_factor": ("FLOAT", { + "default": 0.5, + "min": 0, + "max": 1, + "step": 0.01 + }), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "blend" + + CATEGORY = "_for_testing" + + def blend(self, samples1, samples2, blend_factor:float, 
blend_mode: str="normal"): + + samples_out = samples1.copy() + samples1 = samples1["samples"] + samples2 = samples2["samples"] + + if samples1.shape != samples2.shape: + samples2.permute(0, 3, 1, 2) + samples2 = comfy.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center') + samples2.permute(0, 2, 3, 1) + + samples_blended = self.blend_mode(samples1, samples2, blend_mode) + samples_blended = samples1 * blend_factor + samples_blended * (1 - blend_factor) + samples_out["samples"] = samples_blended + return (samples_out,) + + def blend_mode(self, img1, img2, mode): + if mode == "normal": + return img2 + else: + raise ValueError(f"Unsupported blend mode: {mode}") + class LatentCrop: @classmethod def INPUT_TYPES(s): @@ -1141,12 +1267,14 @@ class SaveImage: for image in images: i = 255. * image.cpu().numpy() img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - metadata = PngInfo() - if prompt is not None: - metadata.add_text("prompt", json.dumps(prompt)) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add_text(x, json.dumps(extra_pnginfo[x])) + metadata = None + if not args.disable_metadata: + metadata = PngInfo() + if prompt is not None: + metadata.add_text("prompt", json.dumps(prompt)) + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata.add_text(x, json.dumps(extra_pnginfo[x])) file = f"{filename}_{counter:05}_.png" img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4) @@ -1320,6 +1448,22 @@ class ImageInvert: s = 1.0 - image return (s,) +class ImageBatch: + + @classmethod + def INPUT_TYPES(s): + return {"required": { "image1": ("IMAGE",), "image2": ("IMAGE",)}} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "batch" + + CATEGORY = "image" + + def batch(self, image1, image2): + if image1.shape[1:] != image2.shape[1:]: + image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1) + s = 
torch.cat((image1, image2), dim=0) + return (s,) class ImagePadForOutpaint: @@ -1405,6 +1549,7 @@ NODE_CLASS_MAPPINGS = { "ImageScale": ImageScale, "ImageScaleBy": ImageScaleBy, "ImageInvert": ImageInvert, + "ImageBatch": ImageBatch, "ImagePadForOutpaint": ImagePadForOutpaint, "ConditioningAverage ": ConditioningAverage , "ConditioningCombine": ConditioningCombine, @@ -1414,6 +1559,7 @@ NODE_CLASS_MAPPINGS = { "KSamplerAdvanced": KSamplerAdvanced, "SetLatentNoiseMask": SetLatentNoiseMask, "LatentComposite": LatentComposite, + "LatentBlend": LatentBlend, "LatentRotate": LatentRotate, "LatentFlip": LatentFlip, "LatentCrop": LatentCrop, @@ -1425,6 +1571,7 @@ NODE_CLASS_MAPPINGS = { "StyleModelApply": StyleModelApply, "unCLIPConditioning": unCLIPConditioning, "ControlNetApply": ControlNetApply, + "ControlNetApplyAdvanced": ControlNetApplyAdvanced, "ControlNetLoader": ControlNetLoader, "DiffControlNetLoader": DiffControlNetLoader, "StyleModelLoader": StyleModelLoader, @@ -1442,6 +1589,7 @@ NODE_CLASS_MAPPINGS = { "SaveLatent": SaveLatent, "ConditioningZeroOut": ConditioningZeroOut, + "ConditioningSetTimestepRange": ConditioningSetTimestepRange, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -1470,6 +1618,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ConditioningSetArea": "Conditioning (Set Area)", "ConditioningSetMask": "Conditioning (Set Mask)", "ControlNetApply": "Apply ControlNet", + "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)", # Latent "VAEEncodeForInpaint": "VAE Encode (for Inpainting)", "SetLatentNoiseMask": "Set Latent Noise Mask", @@ -1482,6 +1631,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentUpscale": "Upscale Latent", "LatentUpscaleBy": "Upscale Latent By", "LatentComposite": "Latent Composite", + "LatentBlend": "Latent Blend", "LatentFromBatch" : "Latent From Batch", "RepeatLatentBatch": "Repeat Latent Batch", # Image @@ -1494,6 +1644,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImageUpscaleWithModel": "Upscale Image (using Model)", "ImageInvert": "Invert Image", 
"ImagePadForOutpaint": "Pad Image for Outpainting", + "ImageBatch": "Batch Images", # _for_testing "VAEDecodeTiled": "VAE Decode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)", diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index 61c277bf6..b1c487101 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -69,6 +69,13 @@ "source": [ "# Checkpoints\n", "\n", + "### SDXL\n", + "### I recommend these workflow examples: https://comfyanonymous.github.io/ComfyUI_examples/sdxl/\n", + "\n", + "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n", + "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n", + "\n", + "\n", "# SD1.5\n", "!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n", "\n", @@ -83,7 +90,7 @@ "#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n", "\n", "# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n", - "#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta2/resolve/main/checkpoints/wd-1-5-beta2-fp16.safetensors -P ./models/checkpoints/\n", + "#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta3/resolve/main/wd-illusion-fp16.safetensors -P ./models/checkpoints/\n", "\n", "\n", "# unCLIP models\n", @@ -100,6 +107,7 @@ "# Loras\n", "#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n", "#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n", + "#!wget -c 
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors -P ./models/loras/ #SDXL offset noise lora\n", "\n", "\n", "# T2I-Adapter\n", @@ -151,13 +159,64 @@ "\n" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "kkkkkkkkkkkkkkk" + }, + "source": [ + "### Run ComfyUI with cloudflared (Recommended Way)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jjjjjjjjjjjjjj" + }, + "outputs": [], + "source": [ + "!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb\n", + "!dpkg -i cloudflared-linux-amd64.deb\n", + "\n", + "import subprocess\n", + "import threading\n", + "import time\n", + "import socket\n", + "import urllib.request\n", + "\n", + "def iframe_thread(port):\n", + " while True:\n", + " time.sleep(0.5)\n", + " sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n", + " result = sock.connect_ex(('127.0.0.1', port))\n", + " if result == 0:\n", + " break\n", + " sock.close()\n", + " print(\"\\nComfyUI finished loading, trying to launch cloudflared (if it gets stuck here cloudflared is having issues)\\n\")\n", + "\n", + " p = subprocess.Popen([\"cloudflared\", \"tunnel\", \"--url\", \"http://127.0.0.1:{}\".format(port)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n", + " for line in p.stderr:\n", + " l = line.decode()\n", + " if \"trycloudflare.com \" in l:\n", + " print(\"This is the URL to access ComfyUI:\", l[l.find(\"http\"):], end='')\n", + " #print(l, end='')\n", + "\n", + "\n", + "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n", + "\n", + "!python main.py --dont-print-server" + ] + }, { "cell_type": "markdown", "metadata": { "id": "kkkkkkkkkkkkkk" }, "source": [ - "### Run ComfyUI with localtunnel (Recommended Way)\n", + "### Run ComfyUI with localtunnel\n", "\n", "\n" ] diff --git a/requirements.txt b/requirements.txt index d632edf79..14524485a 
100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ torch -torchdiffeq torchsde einops transformers>=4.25.1 @@ -10,3 +9,4 @@ pyyaml Pillow scipy tqdm +psutil diff --git a/script_examples/basic_api_example.py b/script_examples/basic_api_example.py index a0e22878b..242d3175f 100644 --- a/script_examples/basic_api_example.py +++ b/script_examples/basic_api_example.py @@ -2,8 +2,12 @@ import json from urllib import request, parse import random -#this is the ComfyUI api prompt format. If you want it for a specific workflow you can copy it from the prompt section -#of the image metadata of images generated with ComfyUI +#This is the ComfyUI api prompt format. + +#If you want it for a specific workflow you can "enable dev mode options" +#in the settings of the UI (gear beside the "Queue Size: ") this will enable +#a button on the UI to save workflows in api format. + #keep in mind ComfyUI is pre alpha software so this format will change a bit. #this is the one for the default workflow diff --git a/server.py b/server.py index 9ca131ede..fab33be3e 100644 --- a/server.py +++ b/server.py @@ -8,7 +8,7 @@ import uuid import json import glob import struct -from PIL import Image +from PIL import Image, ImageOps from io import BytesIO try: @@ -29,6 +29,7 @@ import comfy.model_management class BinaryEventTypes: PREVIEW_IMAGE = 1 + UNENCODED_PREVIEW_IMAGE = 2 async def send_socket_catch_exception(function, message): try: @@ -344,6 +345,11 @@ class PromptServer(): vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) system_stats = { + "system": { + "os": os.name, + "python_version": sys.version, + "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded" + }, "devices": [ { "name": device_name, @@ -498,7 +504,9 @@ class PromptServer(): return prompt_info async def send(self, event, data, 
sid=None): - if isinstance(data, (bytes, bytearray)): + if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: + await self.send_image(data, sid=sid) + elif isinstance(data, (bytes, bytearray)): await self.send_bytes(event, data, sid) else: await self.send_json(event, data, sid) @@ -512,6 +520,30 @@ class PromptServer(): message.extend(data) return message + async def send_image(self, image_data, sid=None): + image_type = image_data[0] + image = image_data[1] + max_size = image_data[2] + if max_size is not None: + if hasattr(Image, 'Resampling'): + resampling = Image.Resampling.BILINEAR + else: + resampling = Image.ANTIALIAS + + image = ImageOps.contain(image, (max_size, max_size), resampling) + type_num = 1 + if image_type == "JPEG": + type_num = 1 + elif image_type == "PNG": + type_num = 2 + + bytesIO = BytesIO() + header = struct.pack(">I", type_num) + bytesIO.write(header) + image.save(bytesIO, format=image_type, quality=95, compress_level=4) + preview_bytes = bytesIO.getvalue() + await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) + async def send_bytes(self, event, data, sid=None): message = self.encode_bytes(event, data) diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js index 662d87e74..152cd7043 100644 --- a/web/extensions/core/contextMenuFilter.js +++ b/web/extensions/core/contextMenuFilter.js @@ -1,4 +1,4 @@ -import {app} from "/scripts/app.js"; +import {app} from "../../scripts/app.js"; // Adds filtering to combo context menus @@ -27,10 +27,13 @@ const ext = { const clickedComboValue = currentNode.widgets .filter(w => w.type === "combo" && w.options.values.length === values.length) .find(w => w.options.values.every((v, i) => v === values[i])) - .value; + ?.value; - let selectedIndex = values.findIndex(v => v === clickedComboValue); - let selectedItem = displayedItems?.[selectedIndex]; + let selectedIndex = clickedComboValue ? 
values.findIndex(v => v === clickedComboValue) : 0; + if (selectedIndex < 0) { + selectedIndex = 0; + } + let selectedItem = displayedItems[selectedIndex]; updateSelected(); // Apply highlighting to the selected item diff --git a/web/extensions/core/linkRenderMode.js b/web/extensions/core/linkRenderMode.js new file mode 100644 index 000000000..1e9091ec1 --- /dev/null +++ b/web/extensions/core/linkRenderMode.js @@ -0,0 +1,25 @@ +import { app } from "../../scripts/app.js"; + +const id = "Comfy.LinkRenderMode"; +const ext = { + name: id, + async setup(app) { + app.ui.settings.addSetting({ + id, + name: "Link Render Mode", + defaultValue: 2, + type: "combo", + options: LiteGraph.LINK_RENDER_MODES.map((m, i) => ({ + value: i, + text: m, + selected: i == app.canvas.links_render_mode, + })), + onChange(value) { + app.canvas.links_render_mode = +value; + app.graph.setDirtyCanvas(true); + }, + }); + }, +}; + +app.registerExtension(ext); diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 7600ce87b..d9eaf8a0c 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -2,7 +2,7 @@ import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js"; import { app } from "../../scripts/app.js"; const CONVERTED_TYPE = "converted-widget"; -const VALID_TYPES = ["STRING", "combo", "number"]; +const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; function isConvertableWidget(widget, config) { return VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0]); diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 2a33bd4a7..356c71ac2 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -9766,6 +9766,7 @@ LGraphNode.prototype.executeAction = function(action) switch (w.type) { case "button": + ctx.fillStyle = background_color; if (w.clicked) { ctx.fillStyle = "#AAA"; w.clicked = false; @@ -9835,7 +9836,11 @@ LGraphNode.prototype.executeAction = 
function(action) ctx.textAlign = "center"; ctx.fillStyle = text_color; ctx.fillText( - w.label || w.name + " " + Number(w.value).toFixed(3), + w.label || w.name + " " + Number(w.value).toFixed( + w.options.precision != null + ? w.options.precision + : 3 + ), widget_width * 0.5, y + H * 0.7 ); @@ -13835,7 +13840,7 @@ LGraphNode.prototype.executeAction = function(action) if (!disabled) { element.addEventListener("click", inner_onclick); } - if (options.autoopen) { + if (!disabled && options.autoopen) { LiteGraph.pointerListenerAdd(element,"enter",inner_over); } diff --git a/web/scripts/api.js b/web/scripts/api.js index d3d15e47e..b1d245d73 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -264,6 +264,15 @@ class ComfyApi extends EventTarget { } } + /** + * Gets system & device stats + * @returns System stats such as python version, OS, per device info + */ + async getSystemStats() { + const res = await this.fetchApi("/system_stats"); + return await res.json(); + } + /** * Sends a POST request to the API * @param {*} type The endpoint to post to diff --git a/web/scripts/app.js b/web/scripts/app.js index 56b060aa6..97d1d51be 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1,3 +1,4 @@ +import { ComfyLogging } from "./logging.js"; import { ComfyWidgets } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; @@ -31,6 +32,7 @@ export class ComfyApp { constructor() { this.ui = new ComfyUI(this); + this.logging = new ComfyLogging(this); /** * List of extensions that are registered with the app @@ -282,6 +284,11 @@ export class ComfyApp { } } + options.push({ + content: "Bypass", + callback: (obj) => { if (this.mode === 4) this.mode = 0; else this.mode = 4; this.graph.change(); } + }); + // prevent conflict of clipspace content if(!ComfyApp.clipspace_return_node) { options.push({ @@ -768,6 +775,19 @@ export class ComfyApp { } block_default = true; } + + if (e.keyCode == 66 && e.ctrlKey) { + if (this.selected_nodes) 
{ + for (var i in this.selected_nodes) { + if (this.selected_nodes[i].mode === 4) { // never + this.selected_nodes[i].mode = 0; // always + } else { + this.selected_nodes[i].mode = 4; // never + } + } + } + block_default = true; + } } this.graph.change(); @@ -914,14 +934,21 @@ export class ComfyApp { const origDrawNode = LGraphCanvas.prototype.drawNode; LGraphCanvas.prototype.drawNode = function (node, ctx) { var editor_alpha = this.editor_alpha; + var old_color = node.bgcolor; if (node.mode === 2) { // never this.editor_alpha = 0.4; } + if (node.mode === 4) { // never + node.bgcolor = "#FF00FF"; + this.editor_alpha = 0.2; + } + const res = origDrawNode.apply(this, arguments); this.editor_alpha = editor_alpha; + node.bgcolor = old_color; return res; }; @@ -1003,6 +1030,7 @@ export class ComfyApp { */ async #loadExtensions() { const extensions = await api.getExtensions(); + this.logging.addEntry("Comfy.App", "debug", { Extensions: extensions }); for (const ext of extensions) { try { await import(api.apiURL(ext)); @@ -1286,6 +1314,9 @@ export class ComfyApp { (t) => `