Merge branch 'comfyanonymous:master' into feature/blockweights

Authored by Dr.Lt.Data on 2023-05-05 17:19:22 +09:00, committed by GitHub
commit 8895cbe7b2
23 changed files with 206 additions and 126 deletions

View File

@@ -5,17 +5,17 @@ import torch
 import torch as th
 import torch.nn as nn
-from ldm.modules.diffusionmodules.util import (
+from ..ldm.modules.diffusionmodules.util import (
     conv_nd,
     linear,
     zero_module,
     timestep_embedding,
 )
-from ldm.modules.attention import SpatialTransformer
-from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
-from ldm.models.diffusion.ddpm import LatentDiffusion
-from ldm.util import log_txt_as_img, exists, instantiate_from_config
+from ..ldm.modules.attention import SpatialTransformer
+from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from ..ldm.models.diffusion.ddpm import LatentDiffusion
+from ..ldm.util import log_txt_as_img, exists, instantiate_from_config
 
 class ControlledUnetModel(UNetModel):

View File

@@ -767,7 +767,7 @@ class UniPC:
                 model_x = self.model_fn(x, vec_t)
                 model_prev_list[-1] = model_x
             if callback is not None:
-                callback(step_index, model_prev_list[-1], x)
+                callback(step_index, model_prev_list[-1], x, steps)
         else:
             raise NotImplementedError()
         if denoise_to_zero:

View File

@@ -1,6 +1,6 @@
 import torch
 from torch import nn, einsum
-from ldm.modules.attention import CrossAttention
+from .ldm.modules.attention import CrossAttention
 from inspect import isfunction

View File

@@ -3,11 +3,11 @@ import torch
 import torch.nn.functional as F
 from contextlib import contextmanager
 
-from ldm.modules.diffusionmodules.model import Encoder, Decoder
-from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
-from ldm.util import instantiate_from_config
-from ldm.modules.ema import LitEma
+from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
+from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+from comfy.ldm.util import instantiate_from_config
+from comfy.ldm.modules.ema import LitEma
 
 # class AutoencoderKL(pl.LightningModule):
 class AutoencoderKL(torch.nn.Module):

View File

@@ -4,7 +4,7 @@ import torch
 import numpy as np
 from tqdm import tqdm
 
-from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
 
 class DDIMSampler(object):

View File

@@ -19,12 +19,12 @@ from tqdm import tqdm
 from torchvision.utils import make_grid
 # from pytorch_lightning.utilities.distributed import rank_zero_only
 
-from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
-from ldm.modules.ema import LitEma
-from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
-from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
-from ldm.models.diffusion.ddim import DDIMSampler
+from comfy.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from comfy.ldm.modules.ema import LitEma
+from comfy.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ..autoencoder import IdentityFirstStage, AutoencoderKL
+from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from .ddim import DDIMSampler
 
 __conditioning_keys__ = {'concat': 'c_concat',

View File

@@ -6,7 +6,7 @@ from torch import nn, einsum
 from einops import rearrange, repeat
 from typing import Optional, Any
 
-from ldm.modules.diffusionmodules.util import checkpoint
+from .diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention
 
 from comfy import model_management
@@ -21,7 +21,7 @@ if model_management.xformers_enabled():
 
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
-from cli_args import args
+from comfy.cli_args import args
 
 def exists(val):
     return val is not None

View File

@@ -6,7 +6,7 @@ import numpy as np
 from einops import rearrange
 from typing import Optional, Any
 
-from ldm.modules.attention import MemoryEfficientCrossAttention
+from ..attention import MemoryEfficientCrossAttention
 from comfy import model_management
 
 if model_management.xformers_enabled_vae():

View File

@@ -6,7 +6,7 @@ import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ldm.modules.diffusionmodules.util import (
+from .util import (
     checkpoint,
     conv_nd,
     linear,
@@ -15,8 +15,8 @@ from ldm.modules.diffusionmodules.util import (
     normalization,
     timestep_embedding,
 )
-from ldm.modules.attention import SpatialTransformer
-from ldm.util import exists
+from ..attention import SpatialTransformer
+from comfy.ldm.util import exists
 
 # dummy replace
@@ -76,12 +76,14 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     support it as an extra input.
     """
 
-    def forward(self, x, emb, context=None, transformer_options={}):
+    def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
         for layer in self:
             if isinstance(layer, TimestepBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
                 x = layer(x, context, transformer_options)
+            elif isinstance(layer, Upsample):
+                x = layer(x, output_shape=output_shape)
             else:
                 x = layer(x)
         return x
@@ -105,14 +107,20 @@ class Upsample(nn.Module):
         if use_conv:
             self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
 
-    def forward(self, x):
+    def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels
         if self.dims == 3:
-            x = F.interpolate(
-                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
-            )
+            shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
+            if output_shape is not None:
+                shape[1] = output_shape[3]
+                shape[2] = output_shape[4]
         else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
+            shape = [x.shape[2] * 2, x.shape[3] * 2]
+            if output_shape is not None:
+                shape[0] = output_shape[2]
+                shape[1] = output_shape[3]
+        x = F.interpolate(x, size=shape, mode="nearest")
         if self.use_conv:
             x = self.conv(x)
         return x
@@ -813,9 +821,14 @@ class UNetModel(nn.Module):
                     ctrl = control['output'].pop()
                     if ctrl is not None:
                         hsp += ctrl
 
             h = th.cat([h, hsp], dim=1)
             del hsp
-            h = module(h, emb, context, transformer_options)
+
+            if len(hs) > 0:
+                output_shape = hs[-1].shape
+            else:
+                output_shape = None
+            h = module(h, emb, context, transformer_options, output_shape)
         h = h.type(x.dtype)
         if self.predict_codebook_ids:
             return self.id_predictor(h)
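
Why output_shape is threaded through TimestepEmbedSequential and Upsample: now that node sizes only need to be multiples of 8 rather than 64, latent dimensions can be odd after downsampling, and a plain scale_factor=2 upsample can overshoot the matching skip connection. Interpolating to the stored skip tensor's spatial size keeps the concatenation valid. A minimal standalone sketch of the mismatch (the tensor sizes below are made up for illustration, not taken from the commit):

import torch
import torch.nn.functional as F

# Illustrative only: a 77-row latent downsampled twice becomes 39 then 20 rows,
# and 20 * 2 = 40 no longer matches the 39-row skip tensor saved on the way down.
# Interpolating to the skip tensor's shape, as the output_shape argument does,
# keeps torch.cat along the channel dim working.
skip = torch.randn(1, 4, 39, 64)   # stored on the downsampling path
h = torch.randn(1, 4, 20, 32)      # current decoder feature map

naive = F.interpolate(h, scale_factor=2, mode="nearest")          # 40 x 64: mismatched
matched = F.interpolate(h, size=skip.shape[2:], mode="nearest")   # 39 x 64: aligned

print(naive.shape, matched.shape)
print(torch.cat([matched, skip], dim=1).shape)  # concatenation now succeeds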

View File

@@ -3,8 +3,8 @@ import torch.nn as nn
 import numpy as np
 from functools import partial
 
-from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
-from ldm.util import default
+from .util import extract_into_tensor, make_beta_schedule
+from comfy.ldm.util import default
 
 class AbstractLowScaleModel(nn.Module):

View File

@@ -15,7 +15,7 @@ import torch.nn as nn
 import numpy as np
 from einops import repeat
 
-from ldm.util import instantiate_from_config
+from comfy.ldm.util import instantiate_from_config
 
 def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):

View File

@@ -1,5 +1,5 @@
-from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
-from ldm.modules.diffusionmodules.openaimodel import Timestep
+from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
+from ..diffusionmodules.openaimodel import Timestep
 import torch
 
 class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):

View File

@@ -1,6 +1,6 @@
 import psutil
 from enum import Enum
-from cli_args import args
+from comfy.cli_args import args
 
 class VRAMState(Enum):
     CPU = 0

View File

@@ -623,7 +623,8 @@ class KSampler:
             ddim_callback = None
             if callback is not None:
-                ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None)
+                total_steps = len(timesteps) - 1
+                ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
 
             sampler = DDIMSampler(self.model, device=self.device)
             sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
@@ -654,13 +655,14 @@ class KSampler:
                 noise = noise * sigmas[0]
 
             k_callback = None
+            total_steps = len(sigmas) - 1
             if callback is not None:
-                k_callback = lambda x: callback(x["i"], x["denoised"], x["x"])
+                k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
 
             if latent_image is not None:
                 noise += latent_image
             if self.sampler == "dpm_fast":
-                samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+                samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
             elif self.sampler == "dpm_adaptive":
                 samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
             else:
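
With this change every sampler path invokes the callback with a fourth argument, the total number of steps, so callers can report progress without inspecting which sampler is in use. A minimal sketch of a compatible callback (hypothetical helper, not part of the commit):

def make_print_callback():
    # Matches the new signature: callback(step, denoised_x0, current_x, total_steps)
    def callback(step, x0, x, total_steps):
        print(f"step {step + 1}/{total_steps}")
    return callback

# e.g. passed as `callback=` to the sampling entry point, assuming it forwards
# the argument unchanged to KSampler.sample.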

View File

@@ -2,8 +2,8 @@ import torch
 import contextlib
 import copy
 
-import sd1_clip
-import sd2_clip
+from . import sd1_clip
+from . import sd2_clip
 from comfy import model_management
 from .ldm.util import instantiate_from_config
 from .ldm.models.autoencoder import AutoencoderKL
@@ -495,10 +495,10 @@ class CLIP:
         else:
             params = {}
 
-        if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
+        if self.target_clip.endswith("FrozenOpenCLIPEmbedder"):
             clip = sd2_clip.SD2ClipModel
             tokenizer = sd2_clip.SD2Tokenizer
-        elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
+        elif self.target_clip.endswith("FrozenCLIPEmbedder"):
             clip = sd1_clip.SD1ClipModel
             tokenizer = sd1_clip.SD1Tokenizer
@@ -563,11 +563,16 @@ class VAE:
         self.device = device
 
     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
+        steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
+        steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
+        steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
+        pbar = utils.ProgressBar(steps)
+
         decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
         output = torch.clamp((
-            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
-            utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
-            utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
+            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
+            utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
+            utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
             / 3.0) / 2.0, min=0.0, max=1.0)
         return output
@@ -611,9 +616,15 @@ class VAE:
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
         pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
-        samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4)
-        samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4)
-        samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4)
+
+        steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
+        steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
+        steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
+        pbar = utils.ProgressBar(steps)
+
+        samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples /= 3.0
         self.first_stage_model = self.first_stage_model.cpu()
         samples = samples.cpu()
@@ -934,9 +945,9 @@ def load_clip(ckpt_path, embedding_directory=None):
     clip_data = utils.load_torch_file(ckpt_path)
     config = {}
     if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
-        config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
+        config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
     else:
-        config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
+        config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
     clip = CLIP(config=config, embedding_directory=embedding_directory)
     clip.load_from_state_dict(clip_data)
     return clip
@@ -1012,9 +1023,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if output_clip:
         clip_config = {}
         if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys:
-            clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
+            clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
         else:
-            clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
+            clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
         clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
         w.cond_stage_model = clip.cond_stage_model
         load_state_dict_to = [w]
@@ -1035,7 +1046,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
             noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
             noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
             params["noise_schedule_config"] = noise_schedule_config
-            noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
+            noise_aug_config['target'] = "comfy.ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
             if size == 1280: #h
                 params["timestep_dim"] = 1024
             elif size == 1024: #l
@@ -1087,19 +1098,19 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
     unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
 
-    sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
-    model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
+    sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
+    model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
 
     if noise_aug_config is not None: #SD2.x unclip model
         sd_config["noise_aug_config"] = noise_aug_config
         sd_config["image_size"] = 96
         sd_config["embedding_dropout"] = 0.25
         sd_config["conditioning_key"] = 'crossattn-adm'
-        model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
+        model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
    elif unet_config["in_channels"] > 4: #inpainting model
         sd_config["conditioning_key"] = "hybrid"
         sd_config["finetune_keys"] = None
-        model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+        model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
     else:
         sd_config["conditioning_key"] = "crossattn"

View File

@@ -191,11 +191,20 @@ def safe_load_embed_zip(embed_path):
             del embed
     return out
 
+def expand_directory_list(directories):
+    dirs = set()
+    for x in directories:
+        dirs.add(x)
+        for root, subdir, file in os.walk(x, followlinks=True):
+            dirs.add(root)
+    return list(dirs)
+
 def load_embed(embedding_name, embedding_directory):
     if isinstance(embedding_directory, str):
         embedding_directory = [embedding_directory]
 
+    embedding_directory = expand_directory_list(embedding_directory)
+
     valid_file = None
     for embed_dir in embedding_directory:
         embed_path = os.path.join(embed_dir, embedding_name)
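
expand_directory_list means embeddings stored in nested subfolders of the configured embeddings directories are now found by filename alone. A small standalone check of the walk behaviour (the temporary layout below is made up for illustration):

import os
import tempfile

def expand_directory_list(directories):
    # Same idea as the helper added above: keep each root plus every
    # directory reachable underneath it (following symlinks).
    dirs = set()
    for x in directories:
        dirs.add(x)
        for root, subdir, file in os.walk(x, followlinks=True):
            dirs.add(root)
    return list(dirs)

base = tempfile.mkdtemp()
os.makedirs(os.path.join(base, "style", "anime"))
print(sorted(expand_directory_list([base])))
# -> [base, base/style, base/style/anime]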

View File

@@ -1,4 +1,5 @@
 import torch
+import math
 
 def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
@@ -62,8 +63,11 @@ def common_upscale(samples, width, height, upscale_method, crop):
         s = samples
     return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
 
+def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
+    return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
+
 @torch.inference_mode()
-def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3):
+def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
     output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
     for b in range(samples.shape[0]):
         s = samples[b:b+1]
@@ -83,6 +87,33 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
                     mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
                 out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask
                 out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask
+                if pbar is not None:
+                    pbar.update(1)
+
         output[b:b+1] = out/out_div
     return output
+
+PROGRESS_BAR_HOOK = None
+def set_progress_bar_global_hook(function):
+    global PROGRESS_BAR_HOOK
+    PROGRESS_BAR_HOOK = function
+
+class ProgressBar:
+    def __init__(self, total):
+        global PROGRESS_BAR_HOOK
+        self.total = total
+        self.current = 0
+        self.hook = PROGRESS_BAR_HOOK
+
+    def update_absolute(self, value, total=None):
+        if total is not None:
+            self.total = total
+        if value > self.total:
+            value = self.total
+        self.current = value
+        if self.hook is not None:
+            self.hook(self.current, self.total)
+
+    def update(self, value):
+        self.update_absolute(self.current + value)
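
ProgressBar reports through a single process-wide hook instead of monkey-patching tqdm, so any front end (web server, CLI, tests) can observe progress the same way. A small usage sketch, assuming it runs inside the ComfyUI tree so that comfy.utils imports; the print hook is illustrative:

import comfy.utils

def print_hook(value, total):
    # Stand-in for the server's send_sync("progress", ...) call.
    print(f"\rprogress: {value}/{total}", end="")

# The hook must be registered before a ProgressBar is constructed,
# because ProgressBar captures the hook in __init__.
comfy.utils.set_progress_bar_global_hook(print_hook)

pbar = comfy.utils.ProgressBar(10)
for _ in range(10):
    pbar.update(1)           # relative advance, as tiled_scale does per tile
pbar.update_absolute(10)     # or jump straight to a known step, as the samplers do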

View File

@@ -18,6 +18,7 @@ def load_hypernetwork_patch(path, strength):
         "swish": torch.nn.Hardswish,
         "tanh": torch.nn.Tanh,
         "sigmoid": torch.nn.Sigmoid,
+        "softsign": torch.nn.Softsign,
     }
 
     if activation_func not in valid_activation:

View File

@@ -37,7 +37,12 @@ class ImageUpscaleWithModel:
         device = model_management.get_torch_device()
         upscale_model.to(device)
         in_img = image.movedim(-1,-3).to(device)
-        s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=128 + 64, tile_y=128 + 64, overlap = 8, upscale_amount=upscale_model.scale)
+
+        tile = 128 + 64
+        overlap = 8
+        steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
+        pbar = comfy.utils.ProgressBar(steps)
+        s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
         upscale_model.cpu()
         s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
         return (s,)
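
The step count handed to the progress bar is simply the number of tile positions, ceil(height / (tile_y - overlap)) * ceil(width / (tile_x - overlap)), multiplied by the batch size at the call sites. A worked example using the node's tile size of 192 and overlap of 8; the 1536x1024 input is hypothetical:

import math

def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    # Same formula as the helper added in comfy/utils.py above.
    return math.ceil(height / (tile_y - overlap)) * math.ceil(width / (tile_x - overlap))

# ceil(1024 / 184) * ceil(1536 / 184) = 6 * 9 = 54 tiles per image
print(get_tiled_scale_steps(1536, 1024, 192, 192, 8))  # 54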

main.py (12 changes)
View File

@@ -5,6 +5,7 @@ import shutil
 import threading
 
 from comfy.cli_args import args
+import comfy.utils
 
 if os.name == "nt":
     import logging
@@ -39,14 +40,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
     await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
 
 def hijack_progress(server):
-    from tqdm.auto import tqdm
-    orig_func = getattr(tqdm, "update")
-    def wrapped_func(*args, **kwargs):
-        pbar = args[0]
-        v = orig_func(*args, **kwargs)
-        server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id)
-        return v
-    setattr(tqdm, "update", wrapped_func)
+    def hook(value, total):
+        server.send_sync("progress", { "value": value, "max": total}, server.client_id)
+    comfy.utils.set_progress_bar_global_hook(hook)
 
 def cleanup_temp():
     temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")

View File

@@ -94,10 +94,10 @@ class ConditioningSetArea:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {"conditioning": ("CONDITIONING", ),
-                              "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
-                              "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
-                              "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
-                              "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
+                              "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+                              "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+                              "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                              "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                               "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              }}
     RETURN_TYPES = ("CONDITIONING",)
@@ -188,16 +188,21 @@ class VAEEncode:
     CATEGORY = "latent"
 
-    def encode(self, vae, pixels):
-        x = (pixels.shape[1] // 64) * 64
-        y = (pixels.shape[2] // 64) * 64
+    @staticmethod
+    def vae_encode_crop_pixels(pixels):
+        x = (pixels.shape[1] // 8) * 8
+        y = (pixels.shape[2] // 8) * 8
         if pixels.shape[1] != x or pixels.shape[2] != y:
-            pixels = pixels[:,:x,:y,:]
+            x_offset = (pixels.shape[1] % 8) // 2
+            y_offset = (pixels.shape[2] % 8) // 2
+            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+        return pixels
 
+    def encode(self, vae, pixels):
+        pixels = self.vae_encode_crop_pixels(pixels)
         t = vae.encode(pixels[:,:,:,:3])
         return ({"samples":t}, )
 
 class VAEEncodeTiled:
     def __init__(self, device="cpu"):
         self.device = device
@@ -211,13 +216,10 @@ class VAEEncodeTiled:
     CATEGORY = "_for_testing"
 
     def encode(self, vae, pixels):
-        x = (pixels.shape[1] // 64) * 64
-        y = (pixels.shape[2] // 64) * 64
-        if pixels.shape[1] != x or pixels.shape[2] != y:
-            pixels = pixels[:,:x,:y,:]
+        pixels = VAEEncode.vae_encode_crop_pixels(pixels)
         t = vae.encode_tiled(pixels[:,:,:,:3])
         return ({"samples":t}, )
 
 class VAEEncodeForInpaint:
     def __init__(self, device="cpu"):
         self.device = device
@@ -231,14 +233,16 @@ class VAEEncodeForInpaint:
     CATEGORY = "latent/inpaint"
 
     def encode(self, vae, pixels, mask, grow_mask_by=6):
-        x = (pixels.shape[1] // 64) * 64
-        y = (pixels.shape[2] // 64) * 64
+        x = (pixels.shape[1] // 8) * 8
+        y = (pixels.shape[2] // 8) * 8
         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
 
         pixels = pixels.clone()
         if pixels.shape[1] != x or pixels.shape[2] != y:
-            pixels = pixels[:,:x,:y,:]
-            mask = mask[:,:,:x,:y]
+            x_offset = (pixels.shape[1] % 8) // 2
+            y_offset = (pixels.shape[2] % 8) // 2
+            pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
+            mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
 
         #grow mask by a few pixels to keep things seamless in latent space
         if grow_mask_by == 0:
@@ -686,8 +690,8 @@ class EmptyLatentImage:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
-                              "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
+        return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+                              "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
                               "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "generate"
@@ -725,8 +729,8 @@ class LatentUpscale:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
-                              "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
-                              "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
+                              "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+                              "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
                               "crop": (s.crop_methods,)}}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "upscale"
@@ -828,8 +832,8 @@ class LatentCrop:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "samples": ("LATENT",),
-                              "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
-                              "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}),
+                              "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
+                              "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
                               "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                               "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                              }}
@@ -854,16 +858,6 @@ class LatentCrop:
         new_width = width // 8
         to_x = new_width + x
         to_y = new_height + y
-        def enforce_image_dim(d, to_d, max_d):
-            if to_d > max_d:
-                leftover = (to_d - max_d) % 8
-                to_d = max_d
-                d -= leftover
-            return (d, to_d)
-
-        #make sure size is always multiple of 64
-        x, to_x = enforce_image_dim(x, to_x, samples.shape[3])
-        y, to_y = enforce_image_dim(y, to_y, samples.shape[2])
         s['samples'] = samples[:,:,y:to_y, x:to_x]
         return (s,)
@@ -897,9 +891,13 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
     if "noise_mask" in latent:
         noise_mask = latent["noise_mask"]
 
+    pbar = comfy.utils.ProgressBar(steps)
+    def callback(step, x0, x, total_steps):
+        pbar.update_absolute(step + 1, total_steps)
+
     samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                                   denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
-                                  force_full_denoise=force_full_denoise, noise_mask=noise_mask)
+                                  force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback)
     out = latent.copy()
     out["samples"] = samples
     return (out, )
@@ -1181,10 +1179,10 @@ class ImagePadForOutpaint:
         return {
             "required": {
                 "image": ("IMAGE",),
-                "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
-                "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
-                "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
-                "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}),
+                "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
                 "feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
             }
         }
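
VAE encoding only requires the pixel dimensions to be multiples of 8 (one latent cell per 8 pixels), and the new helper trims the remainder symmetrically instead of always cutting from one edge. A worked example of the offsets; the 513x780 input is hypothetical:

# Hypothetical input of 513 x 780 pixels (batch, height, width, channels).
h, w = 513, 780
x, y = (h // 8) * 8, (w // 8) * 8          # 512, 776
x_off, y_off = (h % 8) // 2, (w % 8) // 2  # 0, 2

# pixels[:, x_off:x + x_off, y_off:y + y_off, :] keeps rows 0..511 and
# columns 2..777, trimming 2 pixels from each side of the width rather
# than 4 pixels from one edge.
print((x, y, x_off, y_off))  # (512, 776, 0, 2)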

View File

@@ -263,6 +263,34 @@ export class ComfyApp {
 	 */
 	#addDrawBackgroundHandler(node) {
 		const app = this;
+
+		function getImageTop(node) {
+			let shiftY;
+			if (node.imageOffset != null) {
+				shiftY = node.imageOffset;
+			} else {
+				if (node.widgets?.length) {
+					const w = node.widgets[node.widgets.length - 1];
+					shiftY = w.last_y;
+					if (w.computeSize) {
+						shiftY += w.computeSize()[1] + 4;
+					} else {
+						shiftY += LiteGraph.NODE_WIDGET_HEIGHT + 4;
+					}
+				} else {
+					shiftY = node.computeSize()[1];
+				}
+			}
+			return shiftY;
+		}
+
+		node.prototype.setSizeForImage = function () {
+			const minHeight = getImageTop(this) + 220;
+			if (this.size[1] < minHeight) {
+				this.setSize([this.size[0], minHeight]);
+			}
+		};
+
 		node.prototype.onDrawBackground = function (ctx) {
 			if (!this.flags.collapsed) {
 				const output = app.nodeOutputs[this.id + ""];
@@ -283,9 +311,7 @@ export class ComfyApp {
 				).then((imgs) => {
 					if (this.images === output.images) {
 						this.imgs = imgs.filter(Boolean);
-						if (this.size[1] < 100) {
-							this.size[1] = 250;
-						}
+						this.setSizeForImage?.();
 						app.graph.setDirtyCanvas(true);
 					}
 				});
@@ -310,12 +336,7 @@ export class ComfyApp {
 					this.imageIndex = imageIndex = 0;
 				}
 
-				let shiftY;
-				if (this.imageOffset != null) {
-					shiftY = this.imageOffset;
-				} else {
-					shiftY = this.computeSize()[1];
-				}
+				const shiftY = getImageTop(this);
 
 				let dw = this.size[0];
 				let dh = this.size[1];

View File

@@ -261,20 +261,13 @@ export const ComfyWidgets = {
 		let uploadWidget;
 
 		function showImage(name) {
-			// Position the image somewhere sensible
-			if (!node.imageOffset) {
-				node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 25 : 75;
-			}
-
 			const img = new Image();
 			img.onload = () => {
 				node.imgs = [img];
 				app.graph.setDirtyCanvas(true);
 			};
 			img.src = `/view?filename=${name}&type=input`;
-			if ((node.size[1] - node.imageOffset) < 100) {
-				node.size[1] = 250 + node.imageOffset;
-			}
+			node.setSizeForImage?.();
 		}
 
 		// Add our own callback to the combo widget to render an image when it changes