Merge branch 'master' of github.com:comfyanonymous/ComfyUI

doctorpangloss 2024-04-08 10:02:37 -07:00
commit 034ffcea03
22 changed files with 691 additions and 284 deletions

View File

@@ -139,11 +139,12 @@ class ControlBase:
         return out

 class ControlNet(ControlBase):
-    def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
+    def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
-        self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
+        if control_model is not None:
+            self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
         self.global_average_pooling = global_average_pooling
         self.model_sampling_current = None
         self.manual_cast_dtype = manual_cast_dtype
@@ -184,7 +185,9 @@ class ControlNet(ControlBase):
         return self.control_merge(None, control, control_prev, output_dtype)

     def copy(self):
-        c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
         self.copy_to(c)
         return c

View File

@@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys()))
 # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
 code2idx = {"q": 0, "k": 1, "v": 2}

+# This function exists because at the time of writing torch.cat can't do fp8 with cuda
+def cat_tensors(tensors):
+    x = 0
+    for t in tensors:
+        x += t.shape[0]
+
+    shape = [x] + list(tensors[0].shape)[1:]
+    out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
+
+    x = 0
+    for t in tensors:
+        out[x:x + t.shape[0]] = t
+        x += t.shape[0]
+
+    return out

 def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
     new_state_dict = {}
@@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
         if None in tensors:
             raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
-        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+        new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)

     for k_pre, tensors in capture_qkv_bias.items():
         if None in tensors:
             raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
-        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+        new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)

     return new_state_dict
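
Note: the helper sidesteps the missing fp8 CUDA support in torch.cat by preallocating with torch.empty and slice-assigning rows, which works for any dtype. A minimal usage sketch (tensor shapes and dtype are illustrative, not from the commit):

    import torch

    # three fp32 stand-ins for the q/k/v weight tensors being merged
    q = torch.zeros(4, 8)
    k = torch.zeros(4, 8)
    v = torch.zeros(4, 8)
    merged = cat_tensors([q, k, v])  # row-concatenation, same result as torch.cat
    assert merged.shape == (12, 8)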

View File

@@ -10,6 +10,7 @@ from .ldm.cascade.stage_c import StageC
 from .ldm.cascade.stage_b import StageB
 from enum import Enum
 from . import utils
+from . import latent_formats

 class ModelType(Enum):
     EPS = 1
@@ -66,7 +67,8 @@ class BaseModel(torch.nn.Module):
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
             self.adm_channels = 0
-        self.inpaint_model = False
+
+        self.concat_keys = ()
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
@@ -107,8 +109,7 @@ class BaseModel(torch.nn.Module):
     def extra_conds(self, **kwargs):
         out = {}
-        if self.inpaint_model:
-            concat_keys = ("mask", "masked_image")
+        if len(self.concat_keys) > 0:
             cond_concat = []
             denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
             concat_latent_image = kwargs.get("concat_latent_image", None)
@@ -125,24 +126,16 @@ class BaseModel(torch.nn.Module):
                 concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])

-            if len(denoise_mask.shape) == len(noise.shape):
-                denoise_mask = denoise_mask[:,:1]
-
-            denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
-            if denoise_mask.shape[-2:] != noise.shape[-2:]:
-                denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
-            denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
-
-            def blank_inpaint_image_like(latent_image):
-                blank_image = torch.ones_like(latent_image)
-                # these are the values for "zero" in pixel space translated to latent space
-                blank_image[:,0] *= 0.8223
-                blank_image[:,1] *= -0.6876
-                blank_image[:,2] *= 0.6364
-                blank_image[:,3] *= 0.1380
-                return blank_image
-
-            for ck in concat_keys:
+            if denoise_mask is not None:
+                if len(denoise_mask.shape) == len(noise.shape):
+                    denoise_mask = denoise_mask[:,:1]
+
+                denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
+                if denoise_mask.shape[-2:] != noise.shape[-2:]:
+                    denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+                denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
+
+            for ck in self.concat_keys:
                 if denoise_mask is not None:
                     if ck == "mask":
                         cond_concat.append(denoise_mask.to(device))
@@ -152,7 +145,7 @@ class BaseModel(torch.nn.Module):
                     if ck == "mask":
                         cond_concat.append(torch.ones_like(noise)[:,:1])
                     elif ck == "masked_image":
-                        cond_concat.append(blank_inpaint_image_like(noise))
+                        cond_concat.append(self.blank_inpaint_image_like(noise))

             data = torch.cat(cond_concat, dim=1)
             out['c_concat'] = conds.CONDNoiseShape(data)
         adm = self.encode_adm(**kwargs)
@@ -220,7 +213,16 @@ class BaseModel(torch.nn.Module):
         return unet_state_dict

     def set_inpaint(self):
-        self.inpaint_model = True
+        self.concat_keys = ("mask", "masked_image")
+        def blank_inpaint_image_like(latent_image):
+            blank_image = torch.ones_like(latent_image)
+            # these are the values for "zero" in pixel space translated to latent space
+            blank_image[:,0] *= 0.8223
+            blank_image[:,1] *= -0.6876
+            blank_image[:,2] *= 0.6364
+            blank_image[:,3] *= 0.1380
+            return blank_image
+        self.blank_inpaint_image_like = blank_inpaint_image_like

     def memory_required(self, input_shape):
         if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
@@ -471,6 +473,42 @@ class SD_X4Upscaler(BaseModel):
         out['y'] = conds.CONDRegular(noise_level)
         return out

+class IP2P:
+    def extra_conds(self, **kwargs):
+        out = {}
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        if image.shape[1:] != noise.shape[1:]:
+            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        out['c_concat'] = conds.CONDNoiseShape(self.process_ip2p_image_in(image))
+
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out['y'] = conds.CONDRegular(adm)
+        return out
+
+class SD15_instructpix2pix(IP2P, BaseModel):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.process_ip2p_image_in = lambda image: image
+
+class SDXL_instructpix2pix(IP2P, SDXL):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
+        if model_type == ModelType.V_PREDICTION_EDM:
+            self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image) #cosxl ip2p
+        else:
+            self.process_ip2p_image_in = lambda image: image #diffusers ip2p
+
 class StableCascade_C(BaseModel):
     def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=StageC)
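
Note: IP2P is a plain mixin, so Python's MRO resolves extra_conds() from IP2P while everything else (sampling, state-dict handling) still comes from BaseModel or SDXL. Only process_ip2p_image_in differs per subclass: the identity for diffusers-style checkpoints, or the SDXL latent-format shift for CosXL edit models. A hedged illustration using the names from the hunk above (model_config is assumed in scope):

    # V_PREDICTION_EDM marks a CosXL edit model at detection time, so the
    # concat image is routed through latent_formats.SDXL().process_in before
    # being packed into out['c_concat'] by IP2P.extra_conds().
    m = SDXL_instructpix2pix(model_config, model_type=ModelType.V_PREDICTION_EDM)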

View File

@@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix):
     return unet_config

-def model_config_from_unet_config(unet_config):
+def model_config_from_unet_config(unet_config, state_dict=None):
     for model_config in supported_models.models:
-        if model_config.matches(unet_config):
+        if model_config.matches(unet_config, state_dict):
             return model_config(unet_config)

     logging.error("no match {}".format(unet_config))
@@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config):

 def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
     unet_config = detect_unet_config(state_dict, unet_key_prefix)
-    model_config = model_config_from_unet_config(unet_config)
+    model_config = model_config_from_unet_config(unet_config, state_dict)
     if model_config is None and use_base_if_no_match:
         return supported_models_base.BASE(unet_config)
     else:
@@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                  'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
                  'use_temporal_attention': False, 'use_temporal_resblock': False}

+    SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
+                 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
+                 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
+                 'use_temporal_attention': False, 'use_temporal_resblock': False}
+
     SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
               'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
               'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
@@ -351,7 +357,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                   'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
                   'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}

-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS]
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SDXL_diffusers_ip2p]

     for unet_config in supported_models:
         matches = True

View File

@@ -396,6 +396,7 @@ def load_models_gpu(models, memory_required=0):
     inference_memory = minimum_inference_memory()
     extra_mem = max(inference_memory, memory_required)

+    models = set(models)
     models_to_load = []
     models_already_loaded = []
     for x in models:

View File

@@ -150,6 +150,15 @@ class ModelPatcher:
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj

+    def get_model_object(self, name):
+        if name in self.object_patches:
+            return self.object_patches[name]
+        else:
+            if name in self.object_patches_backup:
+                return self.object_patches_backup[name]
+            else:
+                return utils.get_attr(self.model, name)
+
     def model_patches_to(self, device):
         to = self.model_options["transformer_options"]
         if "patches" in to:
@@ -278,7 +287,7 @@ class ModelPatcher:
                 if weight_key in self.patches:
                     m.weight_function = LowVramPatch(weight_key, self)
                 if bias_key in self.patches:
-                    m.bias_function = LowVramPatch(weight_key, self)
+                    m.bias_function = LowVramPatch(bias_key, self)

                 m.prev_comfy_cast_weights = m.comfy_cast_weights
                 m.comfy_cast_weights = True
@@ -462,4 +471,4 @@ class ModelPatcher:
         for k in keys:
             utils.set_attr(self.model, k, self.object_patches_backup[k])

-        self.object_patches_backup = {}
+        self.object_patches_backup.clear()
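
Note: get_model_object gives callers a patch-aware accessor, and clearing object_patches_backup in place (rather than rebinding a fresh dict) keeps any aliased references consistent. A minimal sketch of the accessor (model_patcher is an assumed in-scope ModelPatcher):

    # Returns the object registered with add_object_patch() if present, the
    # pre-patch backup while patches are applied, else the real attribute
    # resolved through utils.get_attr().
    model_sampling = model_patcher.get_model_object("model_sampling")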

comfy/node_helpers.py (new file, 10 lines)

View File

@@ -0,0 +1,10 @@
+def conditioning_set_values(conditioning, values={}):
+    c = []
+    for t in conditioning:
+        n = [t[0], t[1].copy()]
+        for k in values:
+            n[1][k] = values[k]
+        c.append(n)
+
+    return c
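
Note: the helper copies each entry's options dict before overlaying new keys, so callers get fresh conditioning without mutating the input. A minimal usage sketch (cond_tensor and the entry layout are illustrative):

    from comfy import node_helpers

    conditioning = [[cond_tensor, {"strength": 1.0}]]  # list of [tensor, options] pairs
    c = node_helpers.conditioning_set_values(conditioning, {"strength": 0.5})
    assert conditioning[0][1]["strength"] == 1.0  # original options untouched
    assert c[0][1]["strength"] == 0.5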

View File

@@ -30,7 +30,7 @@ from ..model_downloader import get_filename_list_with_downloadable, get_or_downl
 from ..nodes.common import MAX_RESOLUTION
 from .. import controlnet
 from ..open_exr import load_exr
+from .. import node_helpers

 class CLIPTextEncode:
     @classmethod
@@ -140,13 +140,9 @@ class ConditioningSetArea:
     CATEGORY = "conditioning"

     def append(self, conditioning, width, height, x, y, strength):
-        c = []
-        for t in conditioning:
-            n = [t[0], t[1].copy()]
-            n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
-            n[1]['strength'] = strength
-            n[1]['set_area_to_bounds'] = False
-            c.append(n)
+        c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8),
+                                                                "strength": strength,
+                                                                "set_area_to_bounds": False})
         return (c, )

 class ConditioningSetAreaPercentage:
@@ -165,13 +161,9 @@ class ConditioningSetAreaPercentage:
     CATEGORY = "conditioning"

     def append(self, conditioning, width, height, x, y, strength):
-        c = []
-        for t in conditioning:
-            n = [t[0], t[1].copy()]
-            n[1]['area'] = ("percentage", height, width, y, x)
-            n[1]['strength'] = strength
-            n[1]['set_area_to_bounds'] = False
-            c.append(n)
+        c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x),
+                                                                "strength": strength,
+                                                                "set_area_to_bounds": False})
         return (c, )

 class ConditioningSetAreaStrength:
@@ -186,11 +178,7 @@ class ConditioningSetAreaStrength:
     CATEGORY = "conditioning"

     def append(self, conditioning, strength):
-        c = []
-        for t in conditioning:
-            n = [t[0], t[1].copy()]
-            n[1]['strength'] = strength
-            c.append(n)
+        c = node_helpers.conditioning_set_values(conditioning, {"strength": strength})
         return (c, )
@@ -208,19 +196,15 @@ class ConditioningSetMask:
     CATEGORY = "conditioning"

     def append(self, conditioning, mask, set_cond_area, strength):
-        c = []
         set_area_to_bounds = False
         if set_cond_area != "default":
             set_area_to_bounds = True
         if len(mask.shape) < 3:
             mask = mask.unsqueeze(0)
-        for t in conditioning:
-            n = [t[0], t[1].copy()]
-            _, h, w = mask.shape
-            n[1]['mask'] = mask
-            n[1]['set_area_to_bounds'] = set_area_to_bounds
-            n[1]['mask_strength'] = strength
-            c.append(n)
+
+        c = node_helpers.conditioning_set_values(conditioning, {"mask": mask,
+                                                                "set_area_to_bounds": set_area_to_bounds,
+                                                                "mask_strength": strength})
         return (c, )

 class ConditioningZeroOut:
@@ -255,13 +239,8 @@ class ConditioningSetTimestepRange:
     CATEGORY = "advanced/conditioning"

     def set_range(self, conditioning, start, end):
-        c = []
-        for t in conditioning:
-            d = t[1].copy()
-            d['start_percent'] = start
-            d['end_percent'] = end
-            n = [t[0], d]
-            c.append(n)
+        c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start,
+                                                                "end_percent": end})
         return (c, )

 class VAEDecode:
@@ -402,13 +381,8 @@ class InpaintModelConditioning:
         out = []
         for conditioning in [positive, negative]:
-            c = []
-            for t in conditioning:
-                d = t[1].copy()
-                d["concat_latent_image"] = concat_latent
-                d["concat_mask"] = mask
-                n = [t[0], d]
-                c.append(n)
+            c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
+                                                                    "concat_mask": mask})
             out.append(c)
         return (out[0], out[1], out_latent)

View File

@@ -8,14 +8,14 @@ import sys
 import time
 import types
 from contextlib import contextmanager
-from typing import Dict, List
+from typing import Dict, List, Iterable
 from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath

 from . import base_nodes
 from .package_typing import ExportedNodes

-def _vanilla_load_importing_execute_prestartup_script(node_paths: List[str]) -> None:
+def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str]) -> None:
     def execute_script(script_path):
         module_name = splitext(script_path)[0]
         try:
@@ -121,7 +121,7 @@ def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes:
     return exported_nodes

-def _vanilla_load_custom_nodes_2(node_paths: List[str]) -> ExportedNodes:
+def _vanilla_load_custom_nodes_2(node_paths: Iterable[str]) -> ExportedNodes:
     base_node_names = set(base_nodes.NODE_CLASS_MAPPINGS.keys())
     node_import_times = []
     exported_nodes = ExportedNodes()
@@ -192,6 +192,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
     if is_git_repository:
         node_paths += [abspath(join(potential_git_dir_parent, "custom_nodes"))]

+    node_paths = frozenset(abspath(custom_node_path) for custom_node_path in node_paths)
+
     _vanilla_load_importing_execute_prestartup_script(node_paths)
     vanilla_custom_nodes = _vanilla_load_custom_nodes_2(node_paths)
     return vanilla_custom_nodes

View File

@@ -5,6 +5,7 @@ from . import utils
 from . import conds
 import math
 import numpy as np
+import logging

 def prepare_noise(latent_image, seed, noise_inds=None):
     """
@@ -25,94 +26,21 @@ def prepare_noise(latent_image, seed, noise_inds=None):
     noises = torch.cat(noises, axis=0)
     return noises

-def prepare_mask(noise_mask, shape, device):
-    """ensures noise mask is of proper dimensions"""
-    noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
-    noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
-    noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0])
-    noise_mask = noise_mask.to(device)
-    return noise_mask
-
-def get_models_from_cond(cond, model_type):
-    models = []
-    for c in cond:
-        if model_type in c:
-            models += [c[model_type]]
-    return models
-
-def convert_cond(cond):
-    out = []
-    for c in cond:
-        temp = c[1].copy()
-        model_conds = temp.get("model_conds", {})
-        if c[0] is not None:
-            model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
-            temp["cross_attn"] = c[0]
-        temp["model_conds"] = model_conds
-        out.append(temp)
-    return out
-
-def get_additional_models(positive, negative, dtype):
-    """loads additional models in positive and negative conditioning"""
-    control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
-
-    inference_memory = 0
-    control_models = []
-    for m in control_nets:
-        control_models += m.get_models()
-        inference_memory += m.inference_memory_requirements(dtype)
-
-    gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
-    gligen = [x[1] for x in gligen]
-    models = control_models + gligen
-    return models, inference_memory
+def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
+    logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
+    return model, positive, negative, noise_mask, []

 def cleanup_additional_models(models):
-    """cleanup additional models that were loaded"""
-    for m in models:
-        if hasattr(m, 'cleanup'):
-            m.cleanup()
-
-def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
-    device = model.load_device
-    positive = convert_cond(positive)
-    negative = convert_cond(negative)
-
-    if noise_mask is not None:
-        noise_mask = prepare_mask(noise_mask, noise_shape, device)
-
-    real_model = None
-    models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
-    model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
-    real_model = model.model
-
-    return real_model, positive, negative, noise_mask, models
+    logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")

 def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
-    real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
-    noise = noise.to(model.load_device)
-    latent_image = latent_image.to(model.load_device)
-
-    sampler = samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
-
-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
+    sampler = samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
+    samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
     samples = samples.to(model_management.intermediate_device())
-    cleanup_additional_models(models)
-    cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
     return samples

 def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
-    real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
-    noise = noise.to(model.load_device)
-    latent_image = latent_image.to(model.load_device)
-    sigmas = sigmas.to(model.load_device)
-
-    samples = samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
+    samples = samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
     samples = samples.to(model_management.intermediate_device())
-    cleanup_additional_models(models)
-    cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
     return samples
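
Note: comfy.sample.sample and sample_custom keep their public signatures but are now thin pass-throughs; cond conversion, model loading, and cleanup moved behind the samplers.sample/CFGGuider path (see comfy/sampler_helpers.py and comfy/samplers.py below). The retained stubs only warn, e.g. (variables assumed in scope):

    import comfy.sample

    # Logs "comfy.sample.prepare_sampling isn't used anymore and can be removed"
    # and returns its inputs unchanged plus an empty additional-model list.
    m, pos, neg, mask, models = comfy.sample.prepare_sampling(model, noise.shape, positive, negative, None)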

comfy/sampler_helpers.py (new file, 77 lines)

View File

@@ -0,0 +1,77 @@
+import torch
+from comfy import model_management
+from comfy import utils
+from comfy import conds
+
+def prepare_mask(noise_mask, shape, device):
+    """ensures noise mask is of proper dimensions"""
+    noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
+    noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
+    noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0])
+    noise_mask = noise_mask.to(device)
+    return noise_mask
+
+def get_models_from_cond(cond, model_type):
+    models = []
+    for c in cond:
+        if model_type in c:
+            models += [c[model_type]]
+    return models
+
+def convert_cond(cond):
+    out = []
+    for c in cond:
+        temp = c[1].copy()
+        model_conds = temp.get("model_conds", {})
+        if c[0] is not None:
+            model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
+            temp["cross_attn"] = c[0]
+        temp["model_conds"] = model_conds
+        out.append(temp)
+    return out
+
+def get_additional_models(conds, dtype):
+    """loads additional models in conditioning"""
+    cnets = []
+    gligen = []
+
+    for k in conds:
+        cnets += get_models_from_cond(conds[k], "control")
+        gligen += get_models_from_cond(conds[k], "gligen")
+
+    control_nets = set(cnets)
+
+    inference_memory = 0
+    control_models = []
+    for m in control_nets:
+        control_models += m.get_models()
+        inference_memory += m.inference_memory_requirements(dtype)
+
+    gligen = [x[1] for x in gligen]
+    models = control_models + gligen
+    return models, inference_memory
+
+def cleanup_additional_models(models):
+    """cleanup additional models that were loaded"""
+    for m in models:
+        if hasattr(m, 'cleanup'):
+            m.cleanup()
+
+def prepare_sampling(model, noise_shape, conds):
+    device = model.load_device
+    real_model = None
+    models, inference_memory = get_additional_models(conds, model.model_dtype())
+    model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
+    real_model = model.model
+
+    return real_model, conds, models
+
+def cleanup_models(conds, models):
+    cleanup_additional_models(models)
+
+    control_cleanup = []
+    for k in conds:
+        control_cleanup += get_models_from_cond(conds[k], "control")
+
+    cleanup_additional_models(set(control_cleanup))
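
Note: compared to the removed comfy.sample versions, these helpers take a {name: cond_list} dict instead of separate positive/negative arguments, so any number of cond groups can contribute controlnets and gligen models. A hedged sketch (model, positive, and negative are assumed in scope):

    from comfy import sampler_helpers

    conds = {"positive": positive, "negative": negative}  # keys are arbitrary names
    models, inference_memory = sampler_helpers.get_additional_models(conds, model.model_dtype())
    # prepare_sampling() loads these alongside the main model;
    # cleanup_models() releases them (plus controlnets found in conds) afterwards.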

View File

@@ -5,6 +5,7 @@ import collections
 from . import model_management
 import math
 import logging
+from . import sampler_helpers

 from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
@@ -130,30 +131,23 @@ def cond_cat(c_list):
     return out

-def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
-    out_cond = torch.zeros_like(x_in)
-    out_count = torch.ones_like(x_in) * 1e-37
-
-    out_uncond = torch.zeros_like(x_in)
-    out_uncond_count = torch.ones_like(x_in) * 1e-37
-
-    COND = 0
-    UNCOND = 1
-
+def calc_cond_batch(model, conds, x_in, timestep, model_options):
+    out_conds = []
+    out_counts = []
     to_run = []

-    for x in cond:
-        p = get_area_and_mult(x, x_in, timestep)
-        if p is None:
-            continue
-
-        to_run += [(p, COND)]
-    if uncond is not None:
-        for x in uncond:
-            p = get_area_and_mult(x, x_in, timestep)
-            if p is None:
-                continue
-
-            to_run += [(p, UNCOND)]
+    for i in range(len(conds)):
+        out_conds.append(torch.zeros_like(x_in))
+        out_counts.append(torch.ones_like(x_in) * 1e-37)
+
+        cond = conds[i]
+        if cond is not None:
+            for x in cond:
+                p = get_area_and_mult(x, x_in, timestep)
+                if p is None:
+                    continue
+
+                to_run += [(p, i)]

     while len(to_run) > 0:
         first = to_run[0]
@@ -225,74 +219,66 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
             output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
         else:
             output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
-        del input_x

         for o in range(batch_chunks):
-            if cond_or_uncond[o] == COND:
-                out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-            else:
-                out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-        del mult
+            cond_index = cond_or_uncond[o]
+            out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
+            out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]

-    out_cond /= out_count
-    del out_count
-    out_uncond /= out_uncond_count
-    del out_uncond_count
-    return out_cond, out_uncond
+    for i in range(len(out_conds)):
+        out_conds[i] /= out_counts[i]
+
+    return out_conds
+
+def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
+    logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
+    return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
+
+def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
+    if "sampler_cfg_function" in model_options:
+        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
+                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
+        cfg_result = x - model_options["sampler_cfg_function"](args)
+    else:
+        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
+
+    for fn in model_options.get("sampler_post_cfg_function", []):
+        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
+                "sigma": timestep, "model_options": model_options, "input": x}
+        cfg_result = fn(args)
+
+    return cfg_result

 #The main sampling function shared by all the samplers
 #Returns denoised
 def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
     if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
         uncond_ = None
     else:
         uncond_ = uncond

-    cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
-    if "sampler_cfg_function" in model_options:
-        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
-                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
-        cfg_result = x - model_options["sampler_cfg_function"](args)
-    else:
-        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
-
-    for fn in model_options.get("sampler_post_cfg_function", []):
-        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
-                "sigma": timestep, "model_options": model_options, "input": x}
-        cfg_result = fn(args)
-
-    return cfg_result
-
-class CFGNoisePredictor(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.inner_model = model
-    def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
-        out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
-        return out
-    def forward(self, *args, **kwargs):
-        return self.apply_model(*args, **kwargs)
+    conds = [cond, uncond_]
+    out = calc_cond_batch(model, conds, x, timestep, model_options)
+    return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)

-class KSamplerX0Inpaint(torch.nn.Module):
+class KSamplerX0Inpaint:
     def __init__(self, model, sigmas):
-        super().__init__()
         self.inner_model = model
         self.sigmas = sigmas
-    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
+    def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
         if denoise_mask is not None:
             if "denoise_mask_function" in model_options:
                 denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
             latent_mask = 1. - denoise_mask
             x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
-        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
+        out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
         if denoise_mask is not None:
             out = out * denoise_mask + self.latent_image * latent_mask
         return out

-def simple_scheduler(model, steps):
-    s = model.model_sampling
+def simple_scheduler(model_sampling, steps):
+    s = model_sampling
     sigs = []
     ss = len(s.sigmas) / steps
     for x in range(steps):
@@ -300,8 +286,8 @@ def simple_scheduler(model, steps):
     sigs += [0.0]
     return torch.FloatTensor(sigs)

-def ddim_scheduler(model, steps):
-    s = model.model_sampling
+def ddim_scheduler(model_sampling, steps):
+    s = model_sampling
     sigs = []
     ss = max(len(s.sigmas) // steps, 1)
     x = 1
@@ -312,8 +298,8 @@ def ddim_scheduler(model, steps):
     sigs += [0.0]
     return torch.FloatTensor(sigs)

-def normal_scheduler(model, steps, sgm=False, floor=False):
-    s = model.model_sampling
+def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
+    s = model_sampling
     start = s.timestep(s.sigma_max)
     end = s.timestep(s.sigma_min)
@@ -574,59 +560,120 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
     return KSAMPLER(sampler_function, extra_options, inpaint_options)

-def wrap_model(model):
-    model_denoise = CFGNoisePredictor(model)
-    return model_denoise
-
-def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
-    positive = positive[:]
-    negative = negative[:]
-
-    resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
-    resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
-
-    model_wrap = wrap_model(model)
-
-    calculate_start_end_timesteps(model, negative)
-    calculate_start_end_timesteps(model, positive)
-
-    if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
-        latent_image = model.process_latent_in(latent_image)
+def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
+    for k in conds:
+        conds[k] = conds[k][:]
+        resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device)
+
+    for k in conds:
+        calculate_start_end_timesteps(model, conds[k])

     if hasattr(model, 'extra_conds'):
-        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
-        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
+        for k in conds:
+            conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)

     #make sure each cond area has an opposite one with the same area
-    for c in positive:
-        create_cond_with_same_area_if_none(negative, c)
-    for c in negative:
-        create_cond_with_same_area_if_none(positive, c)
+    for k in conds:
+        for c in conds[k]:
+            for kk in conds:
+                if k != kk:
+                    create_cond_with_same_area_if_none(conds[kk], c)

-    pre_run_control(model, negative + positive)
+    for k in conds:
+        pre_run_control(model, conds[k])

-    apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
-    apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
+    if "positive" in conds:
+        positive = conds["positive"]
+        for k in conds:
+            if k != "positive":
+                apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
+                apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])

-    extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
-
-    samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
-    return model.process_latent_out(samples.to(torch.float32))
+    return conds
+
+class CFGGuider:
+    def __init__(self, model_patcher):
+        self.model_patcher = model_patcher
+        self.model_options = model_patcher.model_options
+        self.original_conds = {}
+        self.cfg = 1.0
+
+    def set_conds(self, positive, negative):
+        self.inner_set_conds({"positive": positive, "negative": negative})
+
+    def set_cfg(self, cfg):
+        self.cfg = cfg
+
+    def inner_set_conds(self, conds):
+        for k in conds:
+            self.original_conds[k] = sampler_helpers.convert_cond(conds[k])
+
+    def __call__(self, *args, **kwargs):
+        return self.predict_noise(*args, **kwargs)
+
+    def predict_noise(self, x, timestep, model_options={}, seed=None):
+        return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
+
+    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
+        if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
+            latent_image = self.inner_model.process_latent_in(latent_image)
+
+        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
+
+        extra_args = {"model_options": self.model_options, "seed":seed}
+
+        samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
+        return self.inner_model.process_latent_out(samples.to(torch.float32))
+
+    def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
+        if sigmas.shape[-1] == 0:
+            return latent_image
+
+        self.conds = {}
+        for k in self.original_conds:
+            self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
+
+        self.inner_model, self.conds, self.loaded_models = sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
+        device = self.model_patcher.load_device
+
+        if denoise_mask is not None:
+            denoise_mask = sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
+
+        noise = noise.to(device)
+        latent_image = latent_image.to(device)
+        sigmas = sigmas.to(device)
+
+        output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
+
+        sampler_helpers.cleanup_models(self.conds, self.loaded_models)
+        del self.inner_model
+        del self.conds
+        del self.loaded_models
+        return output

-def calculate_sigmas_scheduler(model, scheduler_name, steps):
+def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
+    cfg_guider = CFGGuider(model)
+    cfg_guider.set_conds(positive, negative)
+    cfg_guider.set_cfg(cfg)
+    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
+
+def calculate_sigmas(model_sampling, scheduler_name, steps):
     if scheduler_name == "karras":
-        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
+        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
     elif scheduler_name == "exponential":
-        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
+        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
     elif scheduler_name == "normal":
-        sigmas = normal_scheduler(model, steps)
+        sigmas = normal_scheduler(model_sampling, steps)
     elif scheduler_name == "simple":
-        sigmas = simple_scheduler(model, steps)
+        sigmas = simple_scheduler(model_sampling, steps)
     elif scheduler_name == "ddim_uniform":
-        sigmas = ddim_scheduler(model, steps)
+        sigmas = ddim_scheduler(model_sampling, steps)
     elif scheduler_name == "sgm_uniform":
-        sigmas = normal_scheduler(model, steps, sgm=True)
+        sigmas = normal_scheduler(model_sampling, steps, sgm=True)
     else:
         logging.error("error invalid scheduler {}".format(scheduler_name))
     return sigmas
@@ -668,7 +715,7 @@ class KSampler:
             steps += 1
             discard_penultimate_sigma = True

-        sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps)
+        sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps)

         if discard_penultimate_sigma:
             sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
@@ -679,9 +726,12 @@ class KSampler:
         if denoise is None or denoise > 0.9999:
             self.sigmas = self.calculate_sigmas(steps).to(self.device)
         else:
-            new_steps = int(steps/denoise)
-            sigmas = self.calculate_sigmas(new_steps).to(self.device)
-            self.sigmas = sigmas[-(steps + 1):]
+            if denoise <= 0.0:
+                self.sigmas = torch.FloatTensor([])
+            else:
+                new_steps = int(steps/denoise)
+                sigmas = self.calculate_sigmas(new_steps).to(self.device)
+                self.sigmas = sigmas[-(steps + 1):]

     def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
         if sigmas is None:
View File

@@ -600,7 +600,7 @@ def load_unet(unet_path):
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
     return model

-def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
+def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
     clip_sd = None
     load_models = [model]
     if clip is not None:
@@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
     model_management.load_models_gpu(load_models)
     clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
     sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
+    for k in extra_keys:
+        sd[k] = extra_keys[k]
+
     utils.save_torch_file(sd, output_path, metadata=metadata)

View File

@@ -174,6 +174,11 @@ class SDXL(supported_models_base.BASE):
             self.sampling_settings["sigma_max"] = 80.0
             self.sampling_settings["sigma_min"] = 0.002
             return model_base.ModelType.EDM
+        elif "edm_vpred.sigma_max" in state_dict:
+            self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
+            if "edm_vpred.sigma_min" in state_dict:
+                self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
+            return model_base.ModelType.V_PREDICTION_EDM
         elif "v_pred" in state_dict:
             return model_base.ModelType.V_PREDICTION
         else:
@@ -334,6 +339,11 @@ class Stable_Zero123(supported_models_base.BASE):
         "num_head_channels": -1,
     }

+    required_keys = {
+        "cc_projection.weight": None,
+        "cc_projection.bias": None,
+    }
+
     clip_vision_prefix = "cond_stage_model.model.visual."

     latent_format = latent_formats.SD15
@@ -439,6 +449,33 @@ class Stable_Cascade_B(Stable_Cascade_C):
         out = model_base.StableCascade_B(self, device=device)
         return out

+class SD15_instructpix2pix(SD15):
+    unet_config = {
+        "context_dim": 768,
+        "model_channels": 320,
+        "use_linear_in_transformer": False,
+        "adm_in_channels": None,
+        "use_temporal_attention": False,
+        "in_channels": 8,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SD15_instructpix2pix(self, device=device)
+
+class SDXL_instructpix2pix(SDXL):
+    unet_config = {
+        "model_channels": 320,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [0, 0, 2, 2, 10, 10],
+        "context_dim": 2048,
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
+        "in_channels": 8,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
+
-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]

 models += [SVD_img2vid]

View File

@@ -16,6 +16,8 @@ class BASE:
         "num_head_channels": 64,
     }

+    required_keys = {}
+
     clip_prefix = []
     clip_vision_prefix = None
     noise_aug_config = None
@@ -28,10 +30,14 @@ class BASE:
     manual_cast_dtype = None

     @classmethod
-    def matches(s, unet_config):
+    def matches(s, unet_config, state_dict=None):
         for k in s.unet_config:
             if k not in unet_config or s.unet_config[k] != unet_config[k]:
                 return False
+        if state_dict is not None:
+            for k in s.required_keys:
+                if k not in state_dict:
+                    return False
         return True

     def model_type(self, state_dict, prefix=""):
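
Note: required_keys lets two model classes whose unet_config would otherwise collide be disambiguated by checkpoint contents; Stable_Zero123 above now demands its cc_projection weights, so an ordinary SD1.5 state dict no longer matches it. An illustrative gate check (the state dicts are fabricated stand-ins):

    sd_zero123 = {"cc_projection.weight": 0, "cc_projection.bias": 0}
    sd_plain = {"some.other.key": 0}
    # with an identical unet_config, only the dict holding the required keys passes
    assert Stable_Zero123.matches(Stable_Zero123.unet_config, sd_zero123)
    assert not Stable_Zero123.matches(Stable_Zero123.unet_config, sd_plain)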

View File

@@ -7247,7 +7247,7 @@ LGraphNode.prototype.executeAction = function(action)
         //create links
         for (var i = 0; i < clipboard_info.links.length; ++i) {
             var link_info = clipboard_info.links[i];
-            var origin_node;
+            var origin_node = undefined;
             var origin_node_relative_id = link_info[0];
             if (origin_node_relative_id != null) {
                 origin_node = nodes[origin_node_relative_id];

View File

@@ -170,9 +170,12 @@ export async function importA1111(graph, parameters) {
     const opts = parameters
         .substr(p)
         .split("\n")[1]
-        .split(",")
+        .match(new RegExp("\\s*([^:]+:\\s*([^\"\\{].*?|\".*?\"|\\{.*?\\}))\\s*(,|$)", "g"))
         .reduce((p, n) => {
             const s = n.split(":");
+            if (s[1].endsWith(',')) {
+                s[1] = s[1].substr(0, s[1].length - 1);
+            }
             p[s[0].trim().toLowerCase()] = s[1].trim();
             return p;
         }, {});
@@ -191,6 +194,7 @@ export async function importA1111(graph, parameters) {
     const vaeLoaderNode = LiteGraph.createNode("VAELoader");
     const saveNode = LiteGraph.createNode("SaveImage");
     let hrSamplerNode = null;
+    let hrSteps = null;

     const ceil64 = (v) => Math.ceil(v / 64) * 64;
@@ -290,6 +294,9 @@ export async function importA1111(graph, parameters) {
             model(v) {
                 setWidgetValue(ckptNode, "ckpt_name", v, true);
             },
+            "vae"(v) {
+                setWidgetValue(vaeLoaderNode, "vae_name", v, true);
+            },
             "cfg scale"(v) {
                 setWidgetValue(samplerNode, "cfg", +v);
             },
@@ -316,6 +323,7 @@ export async function importA1111(graph, parameters) {
                 const h = ceil64(+wxh[1]);
                 const hrUp = popOpt("hires upscale");
                 const hrSz = popOpt("hires resize");
+                hrSteps = popOpt("hires steps");
                 let hrMethod = popOpt("hires upscaler");

                 setWidgetValue(imageNode, "width", w);
@@ -398,7 +406,7 @@ export async function importA1111(graph, parameters) {
     }

     if (hrSamplerNode) {
-        setWidgetValue(hrSamplerNode, "steps", getWidget(samplerNode, "steps").value);
+        setWidgetValue(hrSamplerNode, "steps", hrSteps ? +hrSteps : getWidget(samplerNode, "steps").value);
         setWidgetValue(hrSamplerNode, "cfg", getWidget(samplerNode, "cfg").value);
         setWidgetValue(hrSamplerNode, "scheduler", getWidget(samplerNode, "scheduler").value);
         setWidgetValue(hrSamplerNode, "sampler_name", getWidget(samplerNode, "sampler_name").value);
@@ -415,7 +423,7 @@ export async function importA1111(graph, parameters) {
    graph.arrange();

-   for (const opt of ["model hash", "ensd"]) {
+   for (const opt of ["model hash", "ensd", "version", "vae hash", "ti hashes", "lora hashes", "hashes"]) {
        delete opts[opt];
    }

View File

@@ -6,6 +6,7 @@ from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.cmd import latent_preview
import torch
from comfy import utils
+from comfy import node_helpers


class BasicScheduler:
@@ -26,10 +27,11 @@ class BasicScheduler:
    def get_sigmas(self, model, scheduler, steps, denoise):
        total_steps = steps
        if denoise < 1.0:
+           if denoise <= 0.0:
+               return (torch.FloatTensor([]),)
            total_steps = int(steps/denoise)

-       model_management.load_models_gpu([model])
-       sigmas = samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
+       sigmas = samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu()
        sigmas = sigmas[-(steps + 1):]
        return (sigmas, )
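
A worked example of the denoise arithmetic above, with stand-in values (the list below substitutes for the real samplers.calculate_sigmas output): a longer schedule is computed and only its tail kept, which is what makes denoise < 1.0 start partway down the noise curve.

    steps, denoise = 20, 0.5
    total_steps = int(steps / denoise)         # 40
    sigmas = list(range(total_steps, -1, -1))  # 41 descending values, like a full schedule
    tail = sigmas[-(steps + 1):]               # keep the last steps + 1 = 21
    assert len(tail) == steps + 1
    print(tail[0])  # 20 -> sampling starts halfway down the full noise schedule
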
@@ -162,6 +164,9 @@ class FlipSigmas:
    FUNCTION = "get_sigmas"

    def get_sigmas(self, sigmas):
+       if len(sigmas) == 0:
+           return (sigmas,)
        sigmas = sigmas.flip(0)
        if sigmas[0] == 0:
            sigmas[0] = 0.0001
@@ -334,6 +339,24 @@ class SamplerDPMAdaptative:
                                              "s_noise":s_noise })
        return (sampler, )

+class Noise_EmptyNoise:
+   def __init__(self):
+       self.seed = 0
+
+   def generate_noise(self, input_latent):
+       latent_image = input_latent["samples"]
+       return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
+
+
+class Noise_RandomNoise:
+   def __init__(self, seed):
+       self.seed = seed
+
+   def generate_noise(self, input_latent):
+       latent_image = input_latent["samples"]
+       batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None
+       return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds)
+
class SamplerCustom:
    @classmethod
    def INPUT_TYPES(s):
@@ -361,10 +384,9 @@ class SamplerCustom:
        latent = latent_image
        latent_image = latent["samples"]
        if not add_noise:
-           noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
+           noise = Noise_EmptyNoise().generate_noise(latent)
        else:
-           batch_inds = latent["batch_index"] if "batch_index" in latent else None
-           noise = sample.prepare_noise(latent_image, noise_seed, batch_inds)
+           noise = Noise_RandomNoise(noise_seed).generate_noise(latent)

        noise_mask = None
        if "noise_mask" in latent:
@@ -385,6 +407,161 @@ class SamplerCustom:
            out_denoised = out
        return (out, out_denoised)

+class Guider_Basic(comfy.samplers.CFGGuider):
+   def set_conds(self, positive):
+       self.inner_set_conds({"positive": positive})
+
+class BasicGuider:
+   @classmethod
+   def INPUT_TYPES(s):
+       return {"required":
+                   {"model": ("MODEL",),
+                    "conditioning": ("CONDITIONING", ),
+                    }
+               }
+
+   RETURN_TYPES = ("GUIDER",)
+
+   FUNCTION = "get_guider"
+   CATEGORY = "sampling/custom_sampling/guiders"
+
+   def get_guider(self, model, conditioning):
+       guider = Guider_Basic(model)
+       guider.set_conds(conditioning)
+       return (guider,)
+
+class CFGGuider:
+   @classmethod
+   def INPUT_TYPES(s):
+       return {"required":
+                   {"model": ("MODEL",),
+                    "positive": ("CONDITIONING", ),
+                    "negative": ("CONDITIONING", ),
+                    "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
+                    }
+               }
+
+   RETURN_TYPES = ("GUIDER",)
+
+   FUNCTION = "get_guider"
+   CATEGORY = "sampling/custom_sampling/guiders"
+
+   def get_guider(self, model, positive, negative, cfg):
+       guider = comfy.samplers.CFGGuider(model)
+       guider.set_conds(positive, negative)
+       guider.set_cfg(cfg)
+       return (guider,)
+
+class Guider_DualCFG(comfy.samplers.CFGGuider):
+   def set_cfg(self, cfg1, cfg2):
+       self.cfg1 = cfg1
+       self.cfg2 = cfg2
+
+   def set_conds(self, positive, middle, negative):
+       middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
+       self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative})
+
+   def predict_noise(self, x, timestep, model_options={}, seed=None):
+       negative_cond = self.conds.get("negative", None)
+       middle_cond = self.conds.get("middle", None)
+
+       out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options)
+       return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
+
+class DualCFGGuider:
+   @classmethod
+   def INPUT_TYPES(s):
+       return {"required":
+                   {"model": ("MODEL",),
+                    "cond1": ("CONDITIONING", ),
+                    "cond2": ("CONDITIONING", ),
+                    "negative": ("CONDITIONING", ),
+                    "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
+                    "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
+                    }
+               }
+
+   RETURN_TYPES = ("GUIDER",)
+
+   FUNCTION = "get_guider"
+   CATEGORY = "sampling/custom_sampling/guiders"
+
+   def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
+       guider = Guider_DualCFG(model)
+       guider.set_conds(cond1, cond2, negative)
+       guider.set_cfg(cfg_conds, cfg_cond2_negative)
+       return (guider,)
+
+class DisableNoise:
+   @classmethod
+   def INPUT_TYPES(s):
+       return {"required": {
+                   }
+               }
+
+   RETURN_TYPES = ("NOISE",)
+   FUNCTION = "get_noise"
+   CATEGORY = "sampling/custom_sampling/noise"
+
+   def get_noise(self):
+       return (Noise_EmptyNoise(),)
+
+class RandomNoise(DisableNoise):
+   @classmethod
+   def INPUT_TYPES(s):
+       return {"required": {
+                   "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                   }
+               }
+
+   def get_noise(self, noise_seed):
+       return (Noise_RandomNoise(noise_seed),)
+
+class SamplerCustomAdvanced:
+   @classmethod
+   def INPUT_TYPES(s):
+       return {"required":
+                   {"noise": ("NOISE", ),
+                    "guider": ("GUIDER", ),
+                    "sampler": ("SAMPLER", ),
+                    "sigmas": ("SIGMAS", ),
+                    "latent_image": ("LATENT", ),
+                    }
+               }
+
+   RETURN_TYPES = ("LATENT", "LATENT")
+   RETURN_NAMES = ("output", "denoised_output")
+
+   FUNCTION = "sample"
+   CATEGORY = "sampling/custom_sampling"
+
+   def sample(self, noise, guider, sampler, sigmas, latent_image):
+       latent = latent_image
+       latent_image = latent["samples"]
+
+       noise_mask = None
+       if "noise_mask" in latent:
+           noise_mask = latent["noise_mask"]
+
+       x0_output = {}
+       callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
+
+       disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
+       samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
+       samples = samples.to(comfy.model_management.intermediate_device())
+
+       out = latent.copy()
+       out["samples"] = samples
+       if "x0" in x0_output:
+           out_denoised = latent.copy()
+           out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
+       else:
+           out_denoised = out
+       return (out, out_denoised)
+
NODE_CLASS_MAPPINGS = {
    "SamplerCustom": SamplerCustom,
    "BasicScheduler": BasicScheduler,
@@ -402,4 +579,11 @@ NODE_CLASS_MAPPINGS = {
    "SamplerDPMAdaptative": SamplerDPMAdaptative,
    "SplitSigmas": SplitSigmas,
    "FlipSigmas": FlipSigmas,
+   "CFGGuider": CFGGuider,
+   "DualCFGGuider": DualCFGGuider,
+   "BasicGuider": BasicGuider,
+   "RandomNoise": RandomNoise,
+   "DisableNoise": DisableNoise,
+   "SamplerCustomAdvanced": SamplerCustomAdvanced,
}
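
A minimal sketch of how the new pieces compose, using toy stand-ins rather than ComfyUI's real classes (ToyNoise and ToyGuider are illustrative assumptions): a NOISE object produces the initial noise, a GUIDER owns the conds/CFG, and SamplerCustomAdvanced hands both to the sampler along with the SIGMAS schedule.

    import torch

    class ToyNoise:  # plays the role of Noise_RandomNoise
        def __init__(self, seed):
            self.seed = seed

        def generate_noise(self, latent):
            g = torch.manual_seed(self.seed)
            return torch.randn(latent["samples"].shape, generator=g)

    class ToyGuider:  # plays the role of comfy.samplers.CFGGuider
        def __init__(self, cfg):
            self.cfg = cfg

        def sample(self, noise, latent_image, sampler, sigmas):
            # the real guider runs the denoise loop; here we only show the wiring
            return latent_image + noise * sigmas[0]

    latent = {"samples": torch.zeros(1, 4, 8, 8)}
    noise = ToyNoise(seed=42)
    guider = ToyGuider(cfg=8.0)
    sigmas = torch.linspace(1.0, 0.0, 21)
    out = guider.sample(noise.generate_noise(latent), latent["samples"], "euler", sigmas)
    print(out.shape)  # torch.Size([1, 4, 8, 8])
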

View File

@@ -0,0 +1,45 @@
+import torch
+
+class InstructPixToPixConditioning:
+   @classmethod
+   def INPUT_TYPES(s):
+       return {"required": {"positive": ("CONDITIONING", ),
+                            "negative": ("CONDITIONING", ),
+                            "vae": ("VAE", ),
+                            "pixels": ("IMAGE", ),
+                            }}
+
+   RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
+   RETURN_NAMES = ("positive", "negative", "latent")
+   FUNCTION = "encode"
+
+   CATEGORY = "conditioning/instructpix2pix"
+
+   def encode(self, positive, negative, pixels, vae):
+       x = (pixels.shape[1] // 8) * 8
+       y = (pixels.shape[2] // 8) * 8
+
+       if pixels.shape[1] != x or pixels.shape[2] != y:
+           x_offset = (pixels.shape[1] % 8) // 2
+           y_offset = (pixels.shape[2] % 8) // 2
+           pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+
+       concat_latent = vae.encode(pixels)
+
+       out_latent = {}
+       out_latent["samples"] = torch.zeros_like(concat_latent)
+
+       out = []
+       for conditioning in [positive, negative]:
+           c = []
+           for t in conditioning:
+               d = t[1].copy()
+               d["concat_latent_image"] = concat_latent
+               n = [t[0], d]
+               c.append(n)
+           out.append(c)
+       return (out[0], out[1], out_latent)
+
+NODE_CLASS_MAPPINGS = {
+   "InstructPixToPixConditioning": InstructPixToPixConditioning,
+}
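
The shape of the conditioning data this node manipulates, as a self-contained sketch with stand-in tensors: each conditioning entry is a [cond_tensor, options_dict] pair, and the node copies the dict, stores the VAE-encoded image under "concat_latent_image", and emits a zero latent to be denoised.

    import torch

    concat_latent = torch.zeros(1, 4, 32, 32)       # stand-in for vae.encode(pixels)
    conditioning = [[torch.zeros(1, 77, 768), {}]]  # stand-in CLIP conditioning list

    out = []
    for t in conditioning:
        d = t[1].copy()  # copy so the input conditioning is left untouched
        d["concat_latent_image"] = concat_latent
        out.append([t[0], d])

    print("concat_latent_image" in out[0][1])  # True
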

View File

@ -1,8 +1,10 @@
from comfy import sd, utils from comfy import sd, utils
from comfy import model_base from comfy import model_base
from comfy import model_management from comfy import model_management
from comfy import model_sampling
from comfy.cmd import folder_paths from comfy.cmd import folder_paths
import torch
import json import json
import os import os
@@ -188,6 +190,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
    # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
    # "v2-inpainting"

+   extra_keys = {}
+   _model_sampling = model.get_model_object("model_sampling")
+   if isinstance(_model_sampling, model_sampling.ModelSamplingContinuousEDM):
+       if isinstance(_model_sampling, model_sampling.V_PREDICTION):
+           extra_keys["edm_vpred.sigma_max"] = torch.tensor(_model_sampling.sigma_max).float()
+           extra_keys["edm_vpred.sigma_min"] = torch.tensor(_model_sampling.sigma_min).float()
+
    if model.model.model_type == model_base.ModelType.EPS:
        metadata["modelspec.predict_key"] = "epsilon"
    elif model.model.model_type == model_base.ModelType.V_PREDICTION:
@@ -202,7 +211,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
    output_checkpoint = f"{filename}_{counter:05}_.safetensors"
    output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

-   sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
+   sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)

class CheckpointSave:
    def __init__(self):
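
A hedged sketch of how a downstream loader could recognize checkpoints tagged with these new extra_keys; the helper name and detection logic are illustrative assumptions, not ComfyUI's actual loading code.

    from safetensors import safe_open

    def looks_like_edm_vpred(path):
        # the saver above writes these marker tensors only for v-prediction
        # ModelSamplingContinuousEDM models, so their presence identifies one
        with safe_open(path, framework="pt") as f:
            keys = set(f.keys())
        return "edm_vpred.sigma_max" in keys and "edm_vpred.sigma_min" in keys
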

View File

@ -1,8 +1,10 @@
import torch import torch
from comfy import sample from comfy import sample
from comfy import samplers from comfy import samplers
from comfy import sampler_helpers
#TODO: This node should be removed and replaced with one that uses the new Guider/SamplerCustomAdvanced.
class PerpNeg: class PerpNeg:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@@ -17,7 +19,7 @@ class PerpNeg:
    def patch(self, model, empty_conditioning, neg_scale):
        m = model.clone()
-       nocond = sample.convert_cond(empty_conditioning)
+       nocond = sampler_helpers.convert_cond(empty_conditioning)

        def cfg_function(args):
            model = args["model"]
@@ -29,7 +31,7 @@ class PerpNeg:
            model_options = args["model_options"]
            nocond_processed = samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
-           (noise_pred_nocond, _) = samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
+           (noise_pred_nocond,) = samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options)

            pos = noise_pred_pos - noise_pred_nocond
            neg = noise_pred_neg - noise_pred_nocond

View File

@@ -150,7 +150,7 @@ class SelfAttentionGuidance:
        degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
        degraded_noised = degraded + x - uncond_pred
        # call into the UNet
-       (sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
+       (sag,) = samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
        return cfg_result + (degraded - sag) * sag_scale

    m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
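
This hunk and the PerpNeg one above reflect the same API shift: the old calc_cond_uncond_batch took a fixed (cond, uncond) pair and returned a pair, while the new calc_cond_batch takes a list of conds and returns one prediction per entry, which is why the call sites now unpack a one-element tuple. A toy stand-in showing only the calling convention (not the real sampler code):

    # stand-in for samplers.calc_cond_batch: one output per requested cond
    def calc_cond_batch_toy(model, conds, x, timestep, model_options=None):
        return ["prediction for {}".format(c) for c in conds]

    (noise_pred_nocond,) = calc_cond_batch_toy(None, ["nocond"], None, None)
    print(noise_pred_nocond)  # prediction for nocond
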