Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2024-04-08 10:02:37 -07:00
commit 034ffcea03
22 changed files with 691 additions and 284 deletions

View File

@ -139,11 +139,12 @@ class ControlBase:
return out
class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device)
self.control_model = control_model
self.load_device = load_device
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
if control_model is not None:
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype
@ -184,7 +185,9 @@ class ControlNet(ControlBase):
return self.control_merge(None, control, control_prev, output_dtype)
def copy(self):
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c

View File

@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys()))
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
def cat_tensors(tensors):
    """Concatenate ``tensors`` along dimension 0 without calling ``torch.cat``.

    The result is allocated with ``torch.empty`` using the device and dtype of
    the first tensor, and every input is then copied into its own row slice.
    The trailing dimensions are taken from the first tensor.
    """
    first = tensors[0]
    total_rows = sum(t.shape[0] for t in tensors)
    out = torch.empty([total_rows] + list(first.shape)[1:], device=first.device, dtype=first.dtype)

    offset = 0
    for t in tensors:
        rows = t.shape[0]
        out[offset:offset + rows] = t
        offset += rows
    return out
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
new_state_dict = {}
@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
if None in tensors:
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)
for k_pre, tensors in capture_qkv_bias.items():
if None in tensors:
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)
return new_state_dict

View File

@ -10,6 +10,7 @@ from .ldm.cascade.stage_c import StageC
from .ldm.cascade.stage_b import StageB
from enum import Enum
from . import utils
from . import latent_formats
class ModelType(Enum):
EPS = 1
@ -66,7 +67,8 @@ class BaseModel(torch.nn.Module):
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
self.inpaint_model = False
self.concat_keys = ()
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
@ -107,8 +109,7 @@ class BaseModel(torch.nn.Module):
def extra_conds(self, **kwargs):
out = {}
if self.inpaint_model:
concat_keys = ("mask", "masked_image")
if len(self.concat_keys) > 0:
cond_concat = []
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
concat_latent_image = kwargs.get("concat_latent_image", None)
@ -125,24 +126,16 @@ class BaseModel(torch.nn.Module):
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
if len(denoise_mask.shape) == len(noise.shape):
denoise_mask = denoise_mask[:,:1]
if denoise_mask is not None:
if len(denoise_mask.shape) == len(noise.shape):
denoise_mask = denoise_mask[:,:1]
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364
blank_image[:,3] *= 0.1380
return blank_image
for ck in concat_keys:
for ck in self.concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask.to(device))
@ -152,7 +145,7 @@ class BaseModel(torch.nn.Module):
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
cond_concat.append(self.blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1)
out['c_concat'] = conds.CONDNoiseShape(data)
adm = self.encode_adm(**kwargs)
@ -220,7 +213,16 @@ class BaseModel(torch.nn.Module):
return unet_state_dict
def set_inpaint(self):
self.inpaint_model = True
self.concat_keys = ("mask", "masked_image")
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364
blank_image[:,3] *= 0.1380
return blank_image
self.blank_inpaint_image_like = blank_inpaint_image_like
def memory_required(self, input_shape):
if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
@ -471,6 +473,42 @@ class SD_X4Upscaler(BaseModel):
out['y'] = conds.CONDRegular(noise_level)
return out
class IP2P:
    """Mixin providing ``extra_conds`` for instruct-pix2pix style models.

    The class itself defines neither ``process_ip2p_image_in`` nor
    ``encode_adm``; both are read from ``self`` and must be supplied by the
    class it is combined with.
    """

    def extra_conds(self, **kwargs):
        """Build the extra conditioning dict from the sampling keyword arguments.

        Reads ``concat_latent_image``, ``noise`` and ``device`` from ``kwargs``
        (``device`` is required). Returns a dict with ``c_concat`` always set
        and ``y`` set only when ``encode_adm`` returns something other than
        ``None``.
        """
        noise = kwargs.get("noise", None)
        image = kwargs.get("concat_latent_image", None)
        device = kwargs["device"]

        # Without a reference image, condition on zeros shaped like the noise.
        if image is None:
            image = torch.zeros_like(noise)

        # Resize only when everything past the batch dimension disagrees.
        if image.shape[1:] != noise.shape[1:]:
            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

        image = utils.resize_to_batch_size(image, noise.shape[0])

        out = {'c_concat': conds.CONDNoiseShape(self.process_ip2p_image_in(image))}

        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = conds.CONDRegular(adm)
        return out
class SD15_instructpix2pix(IP2P, BaseModel):
    """``BaseModel`` combined with the ``IP2P`` conditioning mixin."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)

        def passthrough(image):
            # The concat image is used as-is for this model.
            return image

        self.process_ip2p_image_in = passthrough
class SDXL_instructpix2pix(IP2P, SDXL):
    """``SDXL`` combined with the ``IP2P`` conditioning mixin.

    How the concat image is pre-processed depends on ``model_type``.
    """

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)

        def process_with_sdxl_format(image):
            # cosxl ip2p
            return latent_formats.SDXL().process_in(image)

        def passthrough(image):
            # diffusers ip2p
            return image

        is_edm_v_prediction = model_type == ModelType.V_PREDICTION_EDM
        self.process_ip2p_image_in = process_with_sdxl_format if is_edm_v_prediction else passthrough
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageC)

View File

@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix):
return unet_config
def model_config_from_unet_config(unet_config):
def model_config_from_unet_config(unet_config, state_dict=None):
for model_config in supported_models.models:
if model_config.matches(unet_config):
if model_config.matches(unet_config, state_dict):
return model_config(unet_config)
logging.error("no match {}".format(unet_config))
@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config):
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix)
model_config = model_config_from_unet_config(unet_config)
model_config = model_config_from_unet_config(unet_config, state_dict)
if model_config is None and use_base_if_no_match:
return supported_models_base.BASE(unet_config)
else:
@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
@ -351,7 +357,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS]
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SDXL_diffusers_ip2p]
for unet_config in supported_models:
matches = True

View File

@ -396,6 +396,7 @@ def load_models_gpu(models, memory_required=0):
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required)
models = set(models)
models_to_load = []
models_already_loaded = []
for x in models:

View File

@ -150,6 +150,15 @@ class ModelPatcher:
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def get_model_object(self, name):
    """Return the object currently associated with ``name``.

    Lookup order: ``self.object_patches``, then ``self.object_patches_backup``,
    and finally the attribute resolved on ``self.model`` via ``utils.get_attr``.
    """
    if name in self.object_patches:
        return self.object_patches[name]
    if name in self.object_patches_backup:
        return self.object_patches_backup[name]
    return utils.get_attr(self.model, name)
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
@ -278,7 +287,7 @@ class ModelPatcher:
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
if bias_key in self.patches:
m.bias_function = LowVramPatch(weight_key, self)
m.bias_function = LowVramPatch(bias_key, self)
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
@ -462,4 +471,4 @@ class ModelPatcher:
for k in keys:
utils.set_attr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup = {}
self.object_patches_backup.clear()

10
comfy/node_helpers.py Normal file
View File

@ -0,0 +1,10 @@
def conditioning_set_values(conditioning, values=None):
    """Return a new conditioning list with ``values`` written into every entry.

    Each entry ``t`` of ``conditioning`` is rebuilt as ``[t[0], copy_of_t[1]]``
    where the copy (made with ``t[1].copy()``) has every key of ``values`` set.
    The input list and its dicts are left untouched.

    Args:
        conditioning: iterable of entries whose element 0 is kept as-is and
            whose element 1 supports ``.copy()`` and item assignment.
        values: mapping of keys to set on each copied dict. ``None`` (the
            default) means "set nothing".

    Returns:
        A list of two-element lists.
    """
    # Use a None sentinel rather than a shared mutable ``{}`` default.
    if values is None:
        values = {}

    result = []
    for t in conditioning:
        entry = [t[0], t[1].copy()]
        for k in values:
            entry[1][k] = values[k]
        result.append(entry)
    return result

View File

@ -30,7 +30,7 @@ from ..model_downloader import get_filename_list_with_downloadable, get_or_downl
from ..nodes.common import MAX_RESOLUTION
from .. import controlnet
from ..open_exr import load_exr
from .. import node_helpers
class CLIPTextEncode:
@classmethod
@ -140,13 +140,9 @@ class ConditioningSetArea:
CATEGORY = "conditioning"
def append(self, conditioning, width, height, x, y, strength):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
n[1]['strength'] = strength
n[1]['set_area_to_bounds'] = False
c.append(n)
c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8),
"strength": strength,
"set_area_to_bounds": False})
return (c, )
class ConditioningSetAreaPercentage:
@ -165,13 +161,9 @@ class ConditioningSetAreaPercentage:
CATEGORY = "conditioning"
def append(self, conditioning, width, height, x, y, strength):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['area'] = ("percentage", height, width, y, x)
n[1]['strength'] = strength
n[1]['set_area_to_bounds'] = False
c.append(n)
c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x),
"strength": strength,
"set_area_to_bounds": False})
return (c, )
class ConditioningSetAreaStrength:
@ -186,11 +178,7 @@ class ConditioningSetAreaStrength:
CATEGORY = "conditioning"
def append(self, conditioning, strength):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['strength'] = strength
c.append(n)
c = node_helpers.conditioning_set_values(conditioning, {"strength": strength})
return (c, )
@ -208,19 +196,15 @@ class ConditioningSetMask:
CATEGORY = "conditioning"
def append(self, conditioning, mask, set_cond_area, strength):
c = []
set_area_to_bounds = False
if set_cond_area != "default":
set_area_to_bounds = True
if len(mask.shape) < 3:
mask = mask.unsqueeze(0)
for t in conditioning:
n = [t[0], t[1].copy()]
_, h, w = mask.shape
n[1]['mask'] = mask
n[1]['set_area_to_bounds'] = set_area_to_bounds
n[1]['mask_strength'] = strength
c.append(n)
c = node_helpers.conditioning_set_values(conditioning, {"mask": mask,
"set_area_to_bounds": set_area_to_bounds,
"mask_strength": strength})
return (c, )
class ConditioningZeroOut:
@ -255,13 +239,8 @@ class ConditioningSetTimestepRange:
CATEGORY = "advanced/conditioning"
def set_range(self, conditioning, start, end):
c = []
for t in conditioning:
d = t[1].copy()
d['start_percent'] = start
d['end_percent'] = end
n = [t[0], d]
c.append(n)
c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start,
"end_percent": end})
return (c, )
class VAEDecode:
@ -402,13 +381,8 @@ class InpaintModelConditioning:
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
d["concat_latent_image"] = concat_latent
d["concat_mask"] = mask
n = [t[0], d]
c.append(n)
c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
"concat_mask": mask})
out.append(c)
return (out[0], out[1], out_latent)

View File

@ -8,14 +8,14 @@ import sys
import time
import types
from contextlib import contextmanager
from typing import Dict, List
from typing import Dict, List, Iterable
from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
from . import base_nodes
from .package_typing import ExportedNodes
def _vanilla_load_importing_execute_prestartup_script(node_paths: List[str]) -> None:
def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str]) -> None:
def execute_script(script_path):
module_name = splitext(script_path)[0]
try:
@ -121,7 +121,7 @@ def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes:
return exported_nodes
def _vanilla_load_custom_nodes_2(node_paths: List[str]) -> ExportedNodes:
def _vanilla_load_custom_nodes_2(node_paths: Iterable[str]) -> ExportedNodes:
base_node_names = set(base_nodes.NODE_CLASS_MAPPINGS.keys())
node_import_times = []
exported_nodes = ExportedNodes()
@ -192,6 +192,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
if is_git_repository:
node_paths += [abspath(join(potential_git_dir_parent, "custom_nodes"))]
node_paths = frozenset(abspath(custom_node_path) for custom_node_path in node_paths)
_vanilla_load_importing_execute_prestartup_script(node_paths)
vanilla_custom_nodes = _vanilla_load_custom_nodes_2(node_paths)
return vanilla_custom_nodes

View File

@ -5,6 +5,7 @@ from . import utils
from . import conds
import math
import numpy as np
import logging
def prepare_noise(latent_image, seed, noise_inds=None):
"""
@ -25,94 +26,21 @@ def prepare_noise(latent_image, seed, noise_inds=None):
noises = torch.cat(noises, axis=0)
return noises
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device)
return noise_mask
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c:
models += [c[model_type]]
return models
def convert_cond(cond):
out = []
for c in cond:
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
out.append(temp)
return out
def get_additional_models(positive, negative, dtype):
"""loads additional models in positive and negative conditioning"""
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
inference_memory = 0
control_models = []
for m in control_nets:
control_models += m.get_models()
inference_memory += m.inference_memory_requirements(dtype)
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
gligen = [x[1] for x in gligen]
models = control_models + gligen
return models, inference_memory
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
return model, positive, negative, noise_mask, []
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
if hasattr(m, 'cleanup'):
m.cleanup()
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
device = model.load_device
positive = convert_cond(positive)
negative = convert_cond(negative)
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise_shape, device)
real_model = None
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model
return real_model, positive, negative, noise_mask, models
logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
sampler = samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
noise = noise.to(model.load_device)
latent_image = latent_image.to(model.load_device)
sampler = samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(model_management.intermediate_device())
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
noise = noise.to(model.load_device)
latent_image = latent_image.to(model.load_device)
sigmas = sigmas.to(model.load_device)
samples = samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(model_management.intermediate_device())
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples

77
comfy/sampler_helpers.py Normal file
View File

@ -0,0 +1,77 @@
import torch
from comfy import model_management
from comfy import utils
from comfy import conds
def prepare_mask(noise_mask, shape, device):
    """Bring ``noise_mask`` to the layout described by ``shape`` and move it to ``device``.

    The mask is reshaped to a single channel, bilinearly interpolated to
    ``(shape[2], shape[3])``, repeated ``shape[1]`` times along dim 1, passed
    through ``utils.repeat_to_batch_size`` with ``shape[0]``, then moved.
    """
    height, width = noise_mask.shape[-2], noise_mask.shape[-1]
    mask = noise_mask.reshape((-1, 1, height, width))
    mask = torch.nn.functional.interpolate(mask, size=(shape[2], shape[3]), mode="bilinear")
    mask = torch.cat([mask] * shape[1], dim=1)
    mask = utils.repeat_to_batch_size(mask, shape[0])
    return mask.to(device)
def get_models_from_cond(cond, model_type):
    """Collect ``c[model_type]`` from every entry ``c`` of ``cond`` that contains that key."""
    return [c[model_type] for c in cond if model_type in c]
def convert_cond(cond):
    """Turn ``[cross_attn, options]`` entries into flat option dicts.

    For each entry a shallow copy of ``entry[1]`` is made and its
    ``model_conds`` key is (re)assigned. When ``entry[0]`` is not ``None`` it
    is wrapped in ``conds.CONDCrossAttn`` under ``model_conds["c_crossattn"]``
    and also stored under ``cross_attn``.
    """
    converted = []
    for entry in cond:
        options = entry[1].copy()
        # NOTE(review): the copy is shallow, so an existing "model_conds" dict
        # is shared with the input entry.
        model_conds = options.get("model_conds", {})
        cross_attn = entry[0]
        if cross_attn is not None:
            model_conds["c_crossattn"] = conds.CONDCrossAttn(cross_attn) #TODO: remove
            options["cross_attn"] = cross_attn
        options["model_conds"] = model_conds
        converted.append(options)
    return converted
def get_additional_models(conds, dtype):
    """Gather the extra models referenced by the conditioning in ``conds``.

    Every value of ``conds`` is scanned for ``"control"`` and ``"gligen"``
    entries. Control nets are de-duplicated with a set; each contributes its
    ``get_models()`` result and its ``inference_memory_requirements(dtype)``.
    For gligen entries, element 1 is taken as the model.

    Returns:
        ``(models, inference_memory)``: the control models followed by the
        gligen models, and the summed control net memory requirement.
    """
    control_entries = []
    gligen_entries = []
    for key in conds:
        control_entries.extend(get_models_from_cond(conds[key], "control"))
        gligen_entries.extend(get_models_from_cond(conds[key], "gligen"))

    control_models = []
    inference_memory = 0
    for control_net in set(control_entries):
        control_models += control_net.get_models()
        inference_memory += control_net.inference_memory_requirements(dtype)

    gligen_models = [entry[1] for entry in gligen_entries]
    return control_models + gligen_models, inference_memory
def cleanup_additional_models(models):
    """Call ``cleanup()`` on every item of ``models`` that has such an attribute."""
    for m in models:
        if not hasattr(m, 'cleanup'):
            continue
        m.cleanup()
def prepare_sampling(model, noise_shape, conds):
    """Load ``model`` and the extra models referenced by ``conds``.

    The memory passed to ``model_management.load_models_gpu`` is
    ``model.memory_required`` for a shape whose first dimension is twice
    ``noise_shape[0]``, plus the inference memory reported by
    ``get_additional_models``.

    Returns:
        ``(real_model, conds, models)``: ``model.model`` read after loading,
        the ``conds`` argument unchanged, and the list of additional models.
    """
    models, inference_memory = get_additional_models(conds, model.model_dtype())
    # First dimension is doubled when estimating memory.
    # NOTE(review): presumably for cond and uncond batched together - confirm.
    batched_shape = [noise_shape[0] * 2] + list(noise_shape[1:])
    model_management.load_models_gpu([model] + models, model.memory_required(batched_shape) + inference_memory)
    return model.model, conds, models
def cleanup_models(conds, models):
    """Clean up ``models`` and then every distinct control net found in ``conds``."""
    cleanup_additional_models(models)

    control_entries = []
    for key in conds:
        control_entries.extend(get_models_from_cond(conds[key], "control"))

    # De-duplicate so a control net shared between conds is cleaned up once.
    cleanup_additional_models(set(control_entries))

View File

@ -5,6 +5,7 @@ import collections
from . import model_management
import math
import logging
from . import sampler_helpers
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
@ -130,30 +131,23 @@ def cond_cat(c_list):
return out
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in) * 1e-37
COND = 0
UNCOND = 1
def calc_cond_batch(model, conds, x_in, timestep, model_options):
out_conds = []
out_counts = []
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
if uncond is not None:
for x in uncond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)
to_run += [(p, UNCOND)]
cond = conds[i]
if cond is not None:
for x in cond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, i)]
while len(to_run) > 0:
first = to_run[0]
@ -225,74 +219,66 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
if cond_or_uncond[o] == COND:
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
else:
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
del mult
cond_index = cond_or_uncond[o]
out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
out_cond /= out_count
del out_count
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
return out_conds
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
    """Deprecated wrapper: logs a warning and forwards to ``calc_cond_batch``.

    Returns the results for ``[cond, uncond]`` as a tuple.
    """
    logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
    results = calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)
    return tuple(results)
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
    """Combine the cond and uncond predictions into one result.

    If ``model_options`` has a ``"sampler_cfg_function"`` it is called with an
    args dict and its return value is subtracted from ``x``. Otherwise the
    result is ``uncond_pred + (cond_pred - uncond_pred) * cond_scale``.
    Each callable in ``model_options["sampler_post_cfg_function"]`` (if any)
    then replaces the result in turn, receiving the current one as
    ``"denoised"``.
    """
    if "sampler_cfg_function" not in model_options:
        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
    else:
        cfg_args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
                    "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
        cfg_result = x - model_options["sampler_cfg_function"](cfg_args)

    for post_fn in model_options.get("sampler_post_cfg_function", []):
        post_args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
                     "sigma": timestep, "model_options": model_options, "input": x}
        cfg_result = post_fn(post_args)

    return cfg_result
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
conds = [cond, uncond_]
out = calc_cond_batch(model, conds, x, timestep, model_options)
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)
return cfg_result
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
return out
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
class KSamplerX0Inpaint(torch.nn.Module):
class KSamplerX0Inpaint:
def __init__(self, model, sigmas):
super().__init__()
self.inner_model = model
self.sigmas = sigmas
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
if denoise_mask is not None:
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
if denoise_mask is not None:
out = out * denoise_mask + self.latent_image * latent_mask
return out
def simple_scheduler(model, steps):
s = model.model_sampling
def simple_scheduler(model_sampling, steps):
s = model_sampling
sigs = []
ss = len(s.sigmas) / steps
for x in range(steps):
@ -300,8 +286,8 @@ def simple_scheduler(model, steps):
sigs += [0.0]
return torch.FloatTensor(sigs)
def ddim_scheduler(model, steps):
s = model.model_sampling
def ddim_scheduler(model_sampling, steps):
s = model_sampling
sigs = []
ss = max(len(s.sigmas) // steps, 1)
x = 1
@ -312,8 +298,8 @@ def ddim_scheduler(model, steps):
sigs += [0.0]
return torch.FloatTensor(sigs)
def normal_scheduler(model, steps, sgm=False, floor=False):
s = model.model_sampling
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
s = model_sampling
start = s.timestep(s.sigma_max)
end = s.timestep(s.sigma_min)
@ -574,59 +560,120 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return KSAMPLER(sampler_function, extra_options, inpaint_options)
def wrap_model(model):
model_denoise = CFGNoisePredictor(model)
return model_denoise
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
positive = positive[:]
negative = negative[:]
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
for k in conds:
conds[k] = conds[k][:]
resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device)
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
model_wrap = wrap_model(model)
calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive)
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = model.process_latent_in(latent_image)
for k in conds:
calculate_start_end_timesteps(model, conds[k])
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
for c in negative:
create_cond_with_same_area_if_none(positive, c)
for k in conds:
for c in conds[k]:
for kk in conds:
if k != kk:
create_cond_with_same_area_if_none(conds[kk], c)
pre_run_control(model, negative + positive)
for k in conds:
pre_run_control(model, conds[k])
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if "positive" in conds:
positive = conds["positive"]
for k in conds:
if k != "positive":
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
return conds
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))
class CFGGuider:
    """Holds a model patcher plus named conditioning lists and drives a sampler.

    Conditioning is stored once (``original_conds``) and copied for every
    ``sample()`` call. ``inner_model``, ``conds`` and ``loaded_models`` exist
    on the instance only while ``sample()`` is running.
    """

    def __init__(self, model_patcher):
        self.model_patcher = model_patcher
        self.model_options = model_patcher.model_options
        # Converted conds keyed by name ("positive", "negative", ...).
        self.original_conds = {}
        self.cfg = 1.0

    def set_conds(self, positive, negative):
        """Register the conditioning under the "positive" and "negative" keys."""
        self.inner_set_conds({"positive": positive, "negative": negative})

    def set_cfg(self, cfg):
        """Set the scale that ``predict_noise`` passes to ``sampling_function``."""
        self.cfg = cfg

    def inner_set_conds(self, conds):
        """Convert every entry of ``conds`` and store it under the same key."""
        for k in conds:
            self.original_conds[k] = sampler_helpers.convert_cond(conds[k])

    def __call__(self, *args, **kwargs):
        # The sampler receives this object as its model and calls it directly.
        return self.predict_noise(*args, **kwargs)

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        """Run ``sampling_function`` with the current negative/positive conds and cfg."""
        return sampling_function(
            self.inner_model,
            x,
            timestep,
            self.conds.get("negative", None),
            self.conds.get("positive", None),
            self.cfg,
            model_options=model_options,
            seed=seed,
        )

    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
        """Prepare latent and conds, run ``sampler.sample`` and post-process the result.

        Expects ``self.inner_model`` and ``self.conds`` to have been set by
        ``sample()``.
        """
        if latent_image is not None and torch.count_nonzero(latent_image) > 0:  # Don't shift the empty latent image.
            latent_image = self.inner_model.process_latent_in(latent_image)

        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)

        extra_args = {"model_options": self.model_options, "seed": seed}

        samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
        """Sample with the stored conds and return the processed latent.

        Returns ``latent_image`` unchanged when ``sigmas`` is empty. The
        per-run state is cleaned up even when sampling raises; the exception
        is then re-raised to the caller.
        """
        if sigmas.shape[-1] == 0:
            return latent_image

        # Work on copies so repeated calls start from the same conds.
        self.conds = {}
        for k in self.original_conds:
            self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))

        self.inner_model, self.conds, self.loaded_models = sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)

        try:
            device = self.model_patcher.load_device

            if denoise_mask is not None:
                denoise_mask = sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)

            noise = noise.to(device)
            latent_image = latent_image.to(device)
            sigmas = sigmas.to(device)

            output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
        finally:
            # Previously skipped on error, leaving loaded models un-cleaned and
            # stale per-run attributes on the instance.
            sampler_helpers.cleanup_models(self.conds, self.loaded_models)
            del self.inner_model
            del self.conds
            del self.loaded_models
        return output
def calculate_sigmas_scheduler(model, scheduler_name, steps):
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    """Sample through a freshly built CFGGuider configured with the given conds and cfg.

    ``device`` and ``model_options`` are accepted but not read by this
    function; the guider is built from ``model`` alone.
    """
    guider = CFGGuider(model)
    guider.set_conds(positive, negative)
    guider.set_cfg(cfg)
    result = guider.sample(
        noise,
        latent_image,
        sampler,
        sigmas,
        denoise_mask,
        callback,
        disable_pbar,
        seed,
    )
    return result
def calculate_sigmas(model_sampling, scheduler_name, steps):
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "normal":
sigmas = normal_scheduler(model, steps)
sigmas = normal_scheduler(model_sampling, steps)
elif scheduler_name == "simple":
sigmas = simple_scheduler(model, steps)
sigmas = simple_scheduler(model_sampling, steps)
elif scheduler_name == "ddim_uniform":
sigmas = ddim_scheduler(model, steps)
sigmas = ddim_scheduler(model_sampling, steps)
elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model, steps, sgm=True)
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
else:
logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas
@ -668,7 +715,7 @@ class KSampler:
steps += 1
discard_penultimate_sigma = True
sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps)
sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps)
if discard_penultimate_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
@ -679,9 +726,12 @@ class KSampler:
if denoise is None or denoise > 0.9999:
self.sigmas = self.calculate_sigmas(steps).to(self.device)
else:
new_steps = int(steps/denoise)
sigmas = self.calculate_sigmas(new_steps).to(self.device)
self.sigmas = sigmas[-(steps + 1):]
if denoise <= 0.0:
self.sigmas = torch.FloatTensor([])
else:
new_steps = int(steps/denoise)
sigmas = self.calculate_sigmas(new_steps).to(self.device)
self.sigmas = sigmas[-(steps + 1):]
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
if sigmas is None:

View File

@ -600,7 +600,7 @@ def load_unet(unet_path):
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
clip_sd = None
load_models = [model]
if clip is not None:
@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
model_management.load_models_gpu(load_models)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
for k in extra_keys:
sd[k] = extra_keys[k]
utils.save_torch_file(sd, output_path, metadata=metadata)

View File

@ -174,6 +174,11 @@ class SDXL(supported_models_base.BASE):
self.sampling_settings["sigma_max"] = 80.0
self.sampling_settings["sigma_min"] = 0.002
return model_base.ModelType.EDM
elif "edm_vpred.sigma_max" in state_dict:
self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
if "edm_vpred.sigma_min" in state_dict:
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
return model_base.ModelType.V_PREDICTION_EDM
elif "v_pred" in state_dict:
return model_base.ModelType.V_PREDICTION
else:
@ -334,6 +339,11 @@ class Stable_Zero123(supported_models_base.BASE):
"num_head_channels": -1,
}
required_keys = {
"cc_projection.weight": None,
"cc_projection.bias": None,
}
clip_vision_prefix = "cond_stage_model.model.visual."
latent_format = latent_formats.SD15
@ -439,6 +449,33 @@ class Stable_Cascade_B(Stable_Cascade_C):
out = model_base.StableCascade_B(self, device=device)
return out
class SD15_instructpix2pix(SD15):
    """SD15 variant whose unet_config requires 8 input channels."""

    # Keys compared against the detected UNet config when matching a model.
    unet_config = dict(
        context_dim=768,
        model_channels=320,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        use_temporal_attention=False,
        in_channels=8,
    )

    def get_model(self, state_dict, prefix="", device=None):
        """Build the model_base.SD15_instructpix2pix model for this config."""
        built = model_base.SD15_instructpix2pix(self, device=device)
        return built
class SDXL_instructpix2pix(SDXL):
    """SDXL variant whose unet_config requires 8 input channels."""

    # Keys compared against the detected UNet config when matching a model.
    unet_config = dict(
        model_channels=320,
        use_linear_in_transformer=True,
        transformer_depth=[0, 0, 2, 2, 10, 10],
        context_dim=2048,
        adm_in_channels=2816,
        use_temporal_attention=False,
        in_channels=8,
    )

    def get_model(self, state_dict, prefix="", device=None):
        """Build the model_base.SDXL_instructpix2pix model, passing the detected model type."""
        detected_type = self.model_type(state_dict, prefix)
        return model_base.SDXL_instructpix2pix(self, model_type=detected_type, device=device)
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
models += [SVD_img2vid]

View File

@ -16,6 +16,8 @@ class BASE:
"num_head_channels": 64,
}
required_keys = {}
clip_prefix = []
clip_vision_prefix = None
noise_aug_config = None
@ -28,10 +30,14 @@ class BASE:
manual_cast_dtype = None
@classmethod
def matches(s, unet_config):
def matches(s, unet_config, state_dict=None):
for k in s.unet_config:
if k not in unet_config or s.unet_config[k] != unet_config[k]:
return False
if state_dict is not None:
for k in s.required_keys:
if k not in state_dict:
return False
return True
def model_type(self, state_dict, prefix=""):

View File

@ -7247,7 +7247,7 @@ LGraphNode.prototype.executeAction = function(action)
//create links
for (var i = 0; i < clipboard_info.links.length; ++i) {
var link_info = clipboard_info.links[i];
var origin_node;
var origin_node = undefined;
var origin_node_relative_id = link_info[0];
if (origin_node_relative_id != null) {
origin_node = nodes[origin_node_relative_id];

View File

@ -170,9 +170,12 @@ export async function importA1111(graph, parameters) {
const opts = parameters
.substr(p)
.split("\n")[1]
.split(",")
.match(new RegExp("\\s*([^:]+:\\s*([^\"\\{].*?|\".*?\"|\\{.*?\\}))\\s*(,|$)", "g"))
.reduce((p, n) => {
const s = n.split(":");
if (s[1].endsWith(',')) {
s[1] = s[1].substr(0, s[1].length -1);
}
p[s[0].trim().toLowerCase()] = s[1].trim();
return p;
}, {});
@ -191,6 +194,7 @@ export async function importA1111(graph, parameters) {
const vaeLoaderNode = LiteGraph.createNode("VAELoader");
const saveNode = LiteGraph.createNode("SaveImage");
let hrSamplerNode = null;
let hrSteps = null;
const ceil64 = (v) => Math.ceil(v / 64) * 64;
@ -290,6 +294,9 @@ export async function importA1111(graph, parameters) {
model(v) {
setWidgetValue(ckptNode, "ckpt_name", v, true);
},
"vae"(v) {
setWidgetValue(vaeLoaderNode, "vae_name", v, true);
},
"cfg scale"(v) {
setWidgetValue(samplerNode, "cfg", +v);
},
@ -316,6 +323,7 @@ export async function importA1111(graph, parameters) {
const h = ceil64(+wxh[1]);
const hrUp = popOpt("hires upscale");
const hrSz = popOpt("hires resize");
hrSteps = popOpt("hires steps");
let hrMethod = popOpt("hires upscaler");
setWidgetValue(imageNode, "width", w);
@ -398,7 +406,7 @@ export async function importA1111(graph, parameters) {
}
if (hrSamplerNode) {
setWidgetValue(hrSamplerNode, "steps", getWidget(samplerNode, "steps").value);
setWidgetValue(hrSamplerNode, "steps", hrSteps? +hrSteps : getWidget(samplerNode, "steps").value);
setWidgetValue(hrSamplerNode, "cfg", getWidget(samplerNode, "cfg").value);
setWidgetValue(hrSamplerNode, "scheduler", getWidget(samplerNode, "scheduler").value);
setWidgetValue(hrSamplerNode, "sampler_name", getWidget(samplerNode, "sampler_name").value);
@ -415,7 +423,7 @@ export async function importA1111(graph, parameters) {
graph.arrange();
for (const opt of ["model hash", "ensd"]) {
for (const opt of ["model hash", "ensd", "version", "vae hash", "ti hashes", "lora hashes", "hashes"]) {
delete opts[opt];
}

View File

@ -6,6 +6,7 @@ from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.cmd import latent_preview
import torch
from comfy import utils
from comfy import node_helpers
class BasicScheduler:
@ -26,10 +27,11 @@ class BasicScheduler:
def get_sigmas(self, model, scheduler, steps, denoise):
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return (torch.FloatTensor([]),)
total_steps = int(steps/denoise)
model_management.load_models_gpu([model])
sigmas = samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
sigmas = samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu()
sigmas = sigmas[-(steps + 1):]
return (sigmas, )
@ -162,6 +164,9 @@ class FlipSigmas:
FUNCTION = "get_sigmas"
def get_sigmas(self, sigmas):
if len(sigmas) == 0:
return (sigmas,)
sigmas = sigmas.flip(0)
if sigmas[0] == 0:
sigmas[0] = 0.0001
@ -334,6 +339,24 @@ class SamplerDPMAdaptative:
"s_noise":s_noise })
return (sampler, )
class Noise_EmptyNoise:
    """Noise source that produces an all-zero tensor; its seed is fixed at 0."""

    def __init__(self):
        self.seed = 0

    def generate_noise(self, input_latent):
        """Return CPU zeros with the shape, dtype and layout of ``input_latent["samples"]``."""
        reference = input_latent["samples"]
        return torch.zeros(
            reference.shape,
            dtype=reference.dtype,
            layout=reference.layout,
            device="cpu",
        )
class Noise_RandomNoise:
    """Noise source that generates seeded noise via ``sample.prepare_noise``."""

    def __init__(self, seed):
        self.seed = seed

    def generate_noise(self, input_latent):
        """Return noise for ``input_latent["samples"]`` using this object's seed.

        ``input_latent["batch_index"]`` is forwarded when present, otherwise
        ``None`` is passed.
        """
        latent_image = input_latent["samples"]
        batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None
        # NOTE(review): use the bare `sample` module name like the rest of this
        # file; a top-level `comfy` name is not visibly imported here.
        return sample.prepare_noise(latent_image, self.seed, batch_inds)
class SamplerCustom:
@classmethod
def INPUT_TYPES(s):
@ -361,10 +384,9 @@ class SamplerCustom:
latent = latent_image
latent_image = latent["samples"]
if not add_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
noise = Noise_EmptyNoise().generate_noise(latent)
else:
batch_inds = latent["batch_index"] if "batch_index" in latent else None
noise = sample.prepare_noise(latent_image, noise_seed, batch_inds)
noise = Noise_RandomNoise(noise_seed).generate_noise(latent)
noise_mask = None
if "noise_mask" in latent:
@ -385,6 +407,161 @@ class SamplerCustom:
out_denoised = out
return (out, out_denoised)
# NOTE(review): the base class is referenced through the bare `samplers` module
# name, as the rest of this file does; a top-level `comfy` name is not visibly
# imported here and the base-class expression is evaluated at import time.
class Guider_Basic(samplers.CFGGuider):
    """Guider that carries a single conditioning, stored under "positive"."""

    def set_conds(self, positive):
        """Register ``positive`` as the only conditioning."""
        self.inner_set_conds({"positive": positive})
class BasicGuider:
    """Node that wraps a model and one conditioning into a Guider_Basic."""

    RETURN_TYPES = ("GUIDER",)
    FUNCTION = "get_guider"
    CATEGORY = "sampling/custom_sampling/guiders"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "model": ("MODEL",),
            "conditioning": ("CONDITIONING",),
        }
        return {"required": required}

    def get_guider(self, model, conditioning):
        """Return a 1-tuple holding a Guider_Basic configured with ``conditioning``."""
        basic = Guider_Basic(model)
        basic.set_conds(conditioning)
        return (basic,)
class CFGGuider:
    """Node that builds a ``samplers.CFGGuider`` from a model, two conds and a cfg value."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                     "positive": ("CONDITIONING", ),
                     "negative": ("CONDITIONING", ),
                     "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
                     }
                }

    RETURN_TYPES = ("GUIDER",)

    FUNCTION = "get_guider"
    CATEGORY = "sampling/custom_sampling/guiders"

    def get_guider(self, model, positive, negative, cfg):
        """Return a 1-tuple holding the configured guider."""
        # NOTE(review): this node class shares its name with the guider it
        # builds, so the module qualifier is required. The bare `samplers`
        # name is used as elsewhere in this file; a top-level `comfy` name is
        # not visibly imported here.
        guider = samplers.CFGGuider(model)
        guider.set_conds(positive, negative)
        guider.set_cfg(cfg)
        return (guider,)
# NOTE(review): `samplers` is referenced by its bare module name, as the rest
# of this file does; a top-level `comfy` name is not visibly imported here and
# the base-class expression is evaluated at import time.
class Guider_DualCFG(samplers.CFGGuider):
    """Guider combining three conds ("positive", "middle", "negative") with two scales."""

    def set_cfg(self, cfg1, cfg2):
        """Store the two scales used by ``predict_noise``."""
        self.cfg1 = cfg1
        self.cfg2 = cfg2

    def set_conds(self, positive, middle, negative):
        """Register the three conds; ``middle`` is tagged with prompt_type "negative"."""
        middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
        self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative})

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        """Evaluate all three conds in one batch and combine the results.

        The result is ``cfg_function(middle, negative, cfg2)`` plus
        ``(positive - middle) * cfg1``.
        """
        negative_cond = self.conds.get("negative", None)
        middle_cond = self.conds.get("middle", None)

        # out[0] = negative, out[1] = middle, out[2] = positive (order of the list below).
        out = samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options)
        return samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
class DualCFGGuider:
    """Node that builds a Guider_DualCFG from a model, three conds and two cfg values."""

    RETURN_TYPES = ("GUIDER",)
    FUNCTION = "get_guider"
    CATEGORY = "sampling/custom_sampling/guiders"

    @classmethod
    def INPUT_TYPES(s):
        def cfg_float():
            # A fresh options dict per input, so the two specs never share state.
            return ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01})

        required = {}
        required["model"] = ("MODEL",)
        required["cond1"] = ("CONDITIONING",)
        required["cond2"] = ("CONDITIONING",)
        required["negative"] = ("CONDITIONING",)
        required["cfg_conds"] = cfg_float()
        required["cfg_cond2_negative"] = cfg_float()
        return {"required": required}

    def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
        """Return a 1-tuple holding the configured Guider_DualCFG."""
        dual = Guider_DualCFG(model)
        dual.set_conds(cond1, cond2, negative)
        dual.set_cfg(cfg_conds, cfg_cond2_negative)
        return (dual,)
class DisableNoise:
    """Node with no inputs that outputs a Noise_EmptyNoise object."""

    RETURN_TYPES = ("NOISE",)
    FUNCTION = "get_noise"
    CATEGORY = "sampling/custom_sampling/noise"

    @classmethod
    def INPUT_TYPES(s):
        no_inputs = {}
        return {"required": no_inputs}

    def get_noise(self):
        """Return a 1-tuple holding a new Noise_EmptyNoise."""
        empty = Noise_EmptyNoise()
        return (empty,)
class RandomNoise(DisableNoise):
    """Node that outputs a Noise_RandomNoise object for a given seed.

    RETURN_TYPES, FUNCTION and CATEGORY are inherited from DisableNoise.
    """

    @classmethod
    def INPUT_TYPES(s):
        seed_spec = ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff})
        return {"required": {"noise_seed": seed_spec}}

    def get_noise(self, noise_seed):
        """Return a 1-tuple holding a Noise_RandomNoise built from ``noise_seed``."""
        seeded = Noise_RandomNoise(noise_seed)
        return (seeded,)
class SamplerCustomAdvanced:
    """Node that samples a latent using separate noise, guider, sampler and sigmas inputs."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"noise": ("NOISE", ),
                     "guider": ("GUIDER", ),
                     "sampler": ("SAMPLER", ),
                     "sigmas": ("SIGMAS", ),
                     "latent_image": ("LATENT", ),
                     }
                }

    RETURN_TYPES = ("LATENT","LATENT")
    RETURN_NAMES = ("output", "denoised_output")

    FUNCTION = "sample"

    CATEGORY = "sampling/custom_sampling"

    def sample(self, noise, guider, sampler, sigmas, latent_image):
        """Run ``guider.sample`` on the latent and return ``(output, denoised_output)``.

        Both results are copies of the input latent dict with "samples"
        replaced. When the preview callback captured no "x0", the second
        element is the same dict object as the first.
        """
        latent = latent_image
        latent_image = latent["samples"]

        noise_mask = None
        if "noise_mask" in latent:
            noise_mask = latent["noise_mask"]

        # Filled by the preview callback; may gain an "x0" entry during sampling.
        x0_output = {}
        callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)

        # NOTE(review): `utils` and `model_management` are referenced by their
        # bare module names, as the rest of this file does; a top-level `comfy`
        # name is not visibly imported here.
        disable_pbar = not utils.PROGRESS_BAR_ENABLED
        samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
        samples = samples.to(model_management.intermediate_device())

        out = latent.copy()
        out["samples"] = samples
        if "x0" in x0_output:
            out_denoised = latent.copy()
            out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
        else:
            out_denoised = out
        return (out, out_denoised)
NODE_CLASS_MAPPINGS = {
"SamplerCustom": SamplerCustom,
"BasicScheduler": BasicScheduler,
@ -402,4 +579,11 @@ NODE_CLASS_MAPPINGS = {
"SamplerDPMAdaptative": SamplerDPMAdaptative,
"SplitSigmas": SplitSigmas,
"FlipSigmas": FlipSigmas,
"CFGGuider": CFGGuider,
"DualCFGGuider": DualCFGGuider,
"BasicGuider": BasicGuider,
"RandomNoise": RandomNoise,
"DisableNoise": DisableNoise,
"SamplerCustomAdvanced": SamplerCustomAdvanced,
}

View File

@ -0,0 +1,45 @@
import torch
class InstructPixToPixConditioning:
    """Node that attaches a VAE-encoded image to both conditionings as "concat_latent_image"."""

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    FUNCTION = "encode"
    CATEGORY = "conditioning/instructpix2pix"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "positive": ("CONDITIONING",),
            "negative": ("CONDITIONING",),
            "vae": ("VAE",),
            "pixels": ("IMAGE",),
        }
        return {"required": required}

    def encode(self, positive, negative, pixels, vae):
        """Return ``(positive, negative, latent)`` with the encoded image attached.

        Axes 1 and 2 of ``pixels`` are centre-cropped down to multiples of 8
        before encoding. The returned latent's "samples" is a zero tensor
        shaped like the encoded image. The input conditioning dicts are
        copied, not modified.
        """
        size1, size2 = pixels.shape[1], pixels.shape[2]
        keep1 = 8 * (size1 // 8)
        keep2 = 8 * (size2 // 8)
        if (keep1, keep2) != (size1, size2):
            start1 = (size1 % 8) // 2
            start2 = (size2 % 8) // 2
            pixels = pixels[:, start1:start1 + keep1, start2:start2 + keep2, :]

        concat_latent = vae.encode(pixels)
        out_latent = {"samples": torch.zeros_like(concat_latent)}

        def with_concat(conditioning):
            tagged = []
            for entry in conditioning:
                options = entry[1].copy()
                options["concat_latent_image"] = concat_latent
                tagged.append([entry[0], options])
            return tagged

        new_positive = with_concat(positive)
        new_negative = with_concat(negative)
        return (new_positive, new_negative, out_latent)
# Maps each node identifier string to the class that implements it.
# NOTE(review): presumably read by the node loader to register this module's nodes — confirm.
NODE_CLASS_MAPPINGS = {
"InstructPixToPixConditioning": InstructPixToPixConditioning,
}

View File

@ -1,8 +1,10 @@
from comfy import sd, utils
from comfy import model_base
from comfy import model_management
from comfy import model_sampling
from comfy.cmd import folder_paths
import torch
import json
import os
@ -188,6 +190,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
extra_keys = {}
_model_sampling = model.get_model_object("model_sampling")
if isinstance(_model_sampling, model_sampling.ModelSamplingContinuousEDM):
if isinstance(_model_sampling, model_sampling.V_PREDICTION):
extra_keys["edm_vpred.sigma_max"] = torch.tensor(_model_sampling.sigma_max).float()
extra_keys["edm_vpred.sigma_min"] = torch.tensor(_model_sampling.sigma_min).float()
if model.model.model_type == model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == model_base.ModelType.V_PREDICTION:
@ -202,7 +211,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
class CheckpointSave:
def __init__(self):

View File

@ -1,8 +1,10 @@
import torch
from comfy import sample
from comfy import samplers
from comfy import sampler_helpers
#TODO: This node should be removed and replaced with one that uses the new Guider/SamplerCustomAdvanced.
class PerpNeg:
@classmethod
def INPUT_TYPES(s):
@ -17,7 +19,7 @@ class PerpNeg:
def patch(self, model, empty_conditioning, neg_scale):
m = model.clone()
nocond = sample.convert_cond(empty_conditioning)
nocond = sampler_helpers.convert_cond(empty_conditioning)
def cfg_function(args):
model = args["model"]
@ -29,7 +31,7 @@ class PerpNeg:
model_options = args["model_options"]
nocond_processed = samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
(noise_pred_nocond, _) = samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
(noise_pred_nocond,) = samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options)
pos = noise_pred_pos - noise_pred_nocond
neg = noise_pred_neg - noise_pred_nocond

View File

@ -150,7 +150,7 @@ class SelfAttentionGuidance:
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred
# call into the UNet
(sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
(sag,) = samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
return cfg_result + (degraded - sag) * sag_scale
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)