mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in:
commit
034ffcea03
@ -139,11 +139,12 @@ class ControlBase:
|
||||
return out
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
||||
super().__init__(device)
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
if control_model is not None:
|
||||
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.model_sampling_current = None
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
@ -184,7 +185,9 @@ class ControlNet(ControlBase):
|
||||
return self.control_merge(None, control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
c.control_model_wrapped = self.control_model_wrapped
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
|
||||
@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
||||
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
|
||||
def cat_tensors(tensors):
|
||||
x = 0
|
||||
for t in tensors:
|
||||
x += t.shape[0]
|
||||
|
||||
shape = [x] + list(tensors[0].shape)[1:]
|
||||
out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
|
||||
|
||||
x = 0
|
||||
for t in tensors:
|
||||
out[x:x + t.shape[0]] = t
|
||||
x += t.shape[0]
|
||||
|
||||
return out
|
||||
|
||||
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
||||
new_state_dict = {}
|
||||
@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
||||
if None in tensors:
|
||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
|
||||
new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)
|
||||
|
||||
for k_pre, tensors in capture_qkv_bias.items():
|
||||
if None in tensors:
|
||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
|
||||
new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from .ldm.cascade.stage_c import StageC
|
||||
from .ldm.cascade.stage_b import StageB
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
from . import latent_formats
|
||||
|
||||
class ModelType(Enum):
|
||||
EPS = 1
|
||||
@ -66,7 +67,8 @@ class BaseModel(torch.nn.Module):
|
||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||
if self.adm_channels is None:
|
||||
self.adm_channels = 0
|
||||
self.inpaint_model = False
|
||||
|
||||
self.concat_keys = ()
|
||||
logging.info("model_type {}".format(model_type.name))
|
||||
logging.debug("adm {}".format(self.adm_channels))
|
||||
|
||||
@ -107,8 +109,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
if self.inpaint_model:
|
||||
concat_keys = ("mask", "masked_image")
|
||||
if len(self.concat_keys) > 0:
|
||||
cond_concat = []
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
concat_latent_image = kwargs.get("concat_latent_image", None)
|
||||
@ -125,24 +126,16 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
|
||||
|
||||
if len(denoise_mask.shape) == len(noise.shape):
|
||||
denoise_mask = denoise_mask[:,:1]
|
||||
if denoise_mask is not None:
|
||||
if len(denoise_mask.shape) == len(noise.shape):
|
||||
denoise_mask = denoise_mask[:,:1]
|
||||
|
||||
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
|
||||
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
||||
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
||||
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
|
||||
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
||||
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
||||
|
||||
def blank_inpaint_image_like(latent_image):
|
||||
blank_image = torch.ones_like(latent_image)
|
||||
# these are the values for "zero" in pixel space translated to latent space
|
||||
blank_image[:,0] *= 0.8223
|
||||
blank_image[:,1] *= -0.6876
|
||||
blank_image[:,2] *= 0.6364
|
||||
blank_image[:,3] *= 0.1380
|
||||
return blank_image
|
||||
|
||||
for ck in concat_keys:
|
||||
for ck in self.concat_keys:
|
||||
if denoise_mask is not None:
|
||||
if ck == "mask":
|
||||
cond_concat.append(denoise_mask.to(device))
|
||||
@ -152,7 +145,7 @@ class BaseModel(torch.nn.Module):
|
||||
if ck == "mask":
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(blank_inpaint_image_like(noise))
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
out['c_concat'] = conds.CONDNoiseShape(data)
|
||||
adm = self.encode_adm(**kwargs)
|
||||
@ -220,7 +213,16 @@ class BaseModel(torch.nn.Module):
|
||||
return unet_state_dict
|
||||
|
||||
def set_inpaint(self):
|
||||
self.inpaint_model = True
|
||||
self.concat_keys = ("mask", "masked_image")
|
||||
def blank_inpaint_image_like(latent_image):
|
||||
blank_image = torch.ones_like(latent_image)
|
||||
# these are the values for "zero" in pixel space translated to latent space
|
||||
blank_image[:,0] *= 0.8223
|
||||
blank_image[:,1] *= -0.6876
|
||||
blank_image[:,2] *= 0.6364
|
||||
blank_image[:,3] *= 0.1380
|
||||
return blank_image
|
||||
self.blank_inpaint_image_like = blank_inpaint_image_like
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
|
||||
@ -471,6 +473,42 @@ class SD_X4Upscaler(BaseModel):
|
||||
out['y'] = conds.CONDRegular(noise_level)
|
||||
return out
|
||||
|
||||
class IP2P:
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros_like(noise)
|
||||
|
||||
if image.shape[1:] != noise.shape[1:]:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
out['c_concat'] = conds.CONDNoiseShape(self.process_ip2p_image_in(image))
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = conds.CONDRegular(adm)
|
||||
return out
|
||||
|
||||
class SD15_instructpix2pix(IP2P, BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.process_ip2p_image_in = lambda image: image
|
||||
|
||||
class SDXL_instructpix2pix(IP2P, SDXL):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
if model_type == ModelType.V_PREDICTION_EDM:
|
||||
self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image) #cosxl ip2p
|
||||
else:
|
||||
self.process_ip2p_image_in = lambda image: image #diffusers ip2p
|
||||
|
||||
|
||||
class StableCascade_C(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
||||
|
||||
@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
|
||||
return unet_config
|
||||
|
||||
def model_config_from_unet_config(unet_config):
|
||||
def model_config_from_unet_config(unet_config, state_dict=None):
|
||||
for model_config in supported_models.models:
|
||||
if model_config.matches(unet_config):
|
||||
if model_config.matches(unet_config, state_dict):
|
||||
return model_config(unet_config)
|
||||
|
||||
logging.error("no match {}".format(unet_config))
|
||||
@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config):
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||
model_config = model_config_from_unet_config(unet_config)
|
||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return supported_models_base.BASE(unet_config)
|
||||
else:
|
||||
@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
|
||||
@ -351,7 +357,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS]
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SDXL_diffusers_ip2p]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
|
||||
@ -396,6 +396,7 @@ def load_models_gpu(models, memory_required=0):
|
||||
inference_memory = minimum_inference_memory()
|
||||
extra_mem = max(inference_memory, memory_required)
|
||||
|
||||
models = set(models)
|
||||
models_to_load = []
|
||||
models_already_loaded = []
|
||||
for x in models:
|
||||
|
||||
@ -150,6 +150,15 @@ class ModelPatcher:
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
def get_model_object(self, name):
|
||||
if name in self.object_patches:
|
||||
return self.object_patches[name]
|
||||
else:
|
||||
if name in self.object_patches_backup:
|
||||
return self.object_patches_backup[name]
|
||||
else:
|
||||
return utils.get_attr(self.model, name)
|
||||
|
||||
def model_patches_to(self, device):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
@ -278,7 +287,7 @@ class ModelPatcher:
|
||||
if weight_key in self.patches:
|
||||
m.weight_function = LowVramPatch(weight_key, self)
|
||||
if bias_key in self.patches:
|
||||
m.bias_function = LowVramPatch(weight_key, self)
|
||||
m.bias_function = LowVramPatch(bias_key, self)
|
||||
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
@ -462,4 +471,4 @@ class ModelPatcher:
|
||||
for k in keys:
|
||||
utils.set_attr(self.model, k, self.object_patches_backup[k])
|
||||
|
||||
self.object_patches_backup = {}
|
||||
self.object_patches_backup.clear()
|
||||
|
||||
10
comfy/node_helpers.py
Normal file
10
comfy/node_helpers.py
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
def conditioning_set_values(conditioning, values={}):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
for k in values:
|
||||
n[1][k] = values[k]
|
||||
c.append(n)
|
||||
|
||||
return c
|
||||
@ -30,7 +30,7 @@ from ..model_downloader import get_filename_list_with_downloadable, get_or_downl
|
||||
from ..nodes.common import MAX_RESOLUTION
|
||||
from .. import controlnet
|
||||
from ..open_exr import load_exr
|
||||
|
||||
from .. import node_helpers
|
||||
|
||||
class CLIPTextEncode:
|
||||
@classmethod
|
||||
@ -140,13 +140,9 @@ class ConditioningSetArea:
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, conditioning, width, height, x, y, strength):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
|
||||
n[1]['strength'] = strength
|
||||
n[1]['set_area_to_bounds'] = False
|
||||
c.append(n)
|
||||
c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8),
|
||||
"strength": strength,
|
||||
"set_area_to_bounds": False})
|
||||
return (c, )
|
||||
|
||||
class ConditioningSetAreaPercentage:
|
||||
@ -165,13 +161,9 @@ class ConditioningSetAreaPercentage:
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, conditioning, width, height, x, y, strength):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['area'] = ("percentage", height, width, y, x)
|
||||
n[1]['strength'] = strength
|
||||
n[1]['set_area_to_bounds'] = False
|
||||
c.append(n)
|
||||
c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x),
|
||||
"strength": strength,
|
||||
"set_area_to_bounds": False})
|
||||
return (c, )
|
||||
|
||||
class ConditioningSetAreaStrength:
|
||||
@ -186,11 +178,7 @@ class ConditioningSetAreaStrength:
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, conditioning, strength):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['strength'] = strength
|
||||
c.append(n)
|
||||
c = node_helpers.conditioning_set_values(conditioning, {"strength": strength})
|
||||
return (c, )
|
||||
|
||||
|
||||
@ -208,19 +196,15 @@ class ConditioningSetMask:
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, conditioning, mask, set_cond_area, strength):
|
||||
c = []
|
||||
set_area_to_bounds = False
|
||||
if set_cond_area != "default":
|
||||
set_area_to_bounds = True
|
||||
if len(mask.shape) < 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
_, h, w = mask.shape
|
||||
n[1]['mask'] = mask
|
||||
n[1]['set_area_to_bounds'] = set_area_to_bounds
|
||||
n[1]['mask_strength'] = strength
|
||||
c.append(n)
|
||||
|
||||
c = node_helpers.conditioning_set_values(conditioning, {"mask": mask,
|
||||
"set_area_to_bounds": set_area_to_bounds,
|
||||
"mask_strength": strength})
|
||||
return (c, )
|
||||
|
||||
class ConditioningZeroOut:
|
||||
@ -255,13 +239,8 @@ class ConditioningSetTimestepRange:
|
||||
CATEGORY = "advanced/conditioning"
|
||||
|
||||
def set_range(self, conditioning, start, end):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
d = t[1].copy()
|
||||
d['start_percent'] = start
|
||||
d['end_percent'] = end
|
||||
n = [t[0], d]
|
||||
c.append(n)
|
||||
c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start,
|
||||
"end_percent": end})
|
||||
return (c, )
|
||||
|
||||
class VAEDecode:
|
||||
@ -402,13 +381,8 @@ class InpaintModelConditioning:
|
||||
|
||||
out = []
|
||||
for conditioning in [positive, negative]:
|
||||
c = []
|
||||
for t in conditioning:
|
||||
d = t[1].copy()
|
||||
d["concat_latent_image"] = concat_latent
|
||||
d["concat_mask"] = mask
|
||||
n = [t[0], d]
|
||||
c.append(n)
|
||||
c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
|
||||
"concat_mask": mask})
|
||||
out.append(c)
|
||||
return (out[0], out[1], out_latent)
|
||||
|
||||
|
||||
@ -8,14 +8,14 @@ import sys
|
||||
import time
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Iterable
|
||||
from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
|
||||
|
||||
from . import base_nodes
|
||||
from .package_typing import ExportedNodes
|
||||
|
||||
|
||||
def _vanilla_load_importing_execute_prestartup_script(node_paths: List[str]) -> None:
|
||||
def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str]) -> None:
|
||||
def execute_script(script_path):
|
||||
module_name = splitext(script_path)[0]
|
||||
try:
|
||||
@ -121,7 +121,7 @@ def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes:
|
||||
return exported_nodes
|
||||
|
||||
|
||||
def _vanilla_load_custom_nodes_2(node_paths: List[str]) -> ExportedNodes:
|
||||
def _vanilla_load_custom_nodes_2(node_paths: Iterable[str]) -> ExportedNodes:
|
||||
base_node_names = set(base_nodes.NODE_CLASS_MAPPINGS.keys())
|
||||
node_import_times = []
|
||||
exported_nodes = ExportedNodes()
|
||||
@ -192,6 +192,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
|
||||
if is_git_repository:
|
||||
node_paths += [abspath(join(potential_git_dir_parent, "custom_nodes"))]
|
||||
|
||||
node_paths = frozenset(abspath(custom_node_path) for custom_node_path in node_paths)
|
||||
|
||||
_vanilla_load_importing_execute_prestartup_script(node_paths)
|
||||
vanilla_custom_nodes = _vanilla_load_custom_nodes_2(node_paths)
|
||||
return vanilla_custom_nodes
|
||||
|
||||
@ -5,6 +5,7 @@ from . import utils
|
||||
from . import conds
|
||||
import math
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
@ -25,94 +26,21 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
noises = torch.cat(noises, axis=0)
|
||||
return noises
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
||||
noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
|
||||
def get_models_from_cond(cond, model_type):
|
||||
models = []
|
||||
for c in cond:
|
||||
if model_type in c:
|
||||
models += [c[model_type]]
|
||||
return models
|
||||
|
||||
def convert_cond(cond):
|
||||
out = []
|
||||
for c in cond:
|
||||
temp = c[1].copy()
|
||||
model_conds = temp.get("model_conds", {})
|
||||
if c[0] is not None:
|
||||
model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
|
||||
temp["cross_attn"] = c[0]
|
||||
temp["model_conds"] = model_conds
|
||||
out.append(temp)
|
||||
return out
|
||||
|
||||
def get_additional_models(positive, negative, dtype):
|
||||
"""loads additional models in positive and negative conditioning"""
|
||||
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
|
||||
|
||||
inference_memory = 0
|
||||
control_models = []
|
||||
for m in control_nets:
|
||||
control_models += m.get_models()
|
||||
inference_memory += m.inference_memory_requirements(dtype)
|
||||
|
||||
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
|
||||
gligen = [x[1] for x in gligen]
|
||||
models = control_models + gligen
|
||||
return models, inference_memory
|
||||
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
||||
logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
|
||||
return model, positive, negative, noise_mask, []
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
"""cleanup additional models that were loaded"""
|
||||
for m in models:
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
||||
device = model.load_device
|
||||
positive = convert_cond(positive)
|
||||
negative = convert_cond(negative)
|
||||
|
||||
if noise_mask is not None:
|
||||
noise_mask = prepare_mask(noise_mask, noise_shape, device)
|
||||
|
||||
real_model = None
|
||||
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
|
||||
model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, positive, negative, noise_mask, models
|
||||
|
||||
logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")
|
||||
|
||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
|
||||
sampler = samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
noise = noise.to(model.load_device)
|
||||
latent_image = latent_image.to(model.load_device)
|
||||
|
||||
sampler = samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(model_management.intermediate_device())
|
||||
|
||||
cleanup_additional_models(models)
|
||||
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
|
||||
return samples
|
||||
|
||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
|
||||
noise = noise.to(model.load_device)
|
||||
latent_image = latent_image.to(model.load_device)
|
||||
sigmas = sigmas.to(model.load_device)
|
||||
|
||||
samples = samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(model_management.intermediate_device())
|
||||
cleanup_additional_models(models)
|
||||
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
|
||||
return samples
|
||||
|
||||
|
||||
77
comfy/sampler_helpers.py
Normal file
77
comfy/sampler_helpers.py
Normal file
@ -0,0 +1,77 @@
|
||||
import torch
|
||||
from comfy import model_management
|
||||
from comfy import utils
|
||||
from comfy import conds
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
||||
noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
|
||||
def get_models_from_cond(cond, model_type):
|
||||
models = []
|
||||
for c in cond:
|
||||
if model_type in c:
|
||||
models += [c[model_type]]
|
||||
return models
|
||||
|
||||
def convert_cond(cond):
|
||||
out = []
|
||||
for c in cond:
|
||||
temp = c[1].copy()
|
||||
model_conds = temp.get("model_conds", {})
|
||||
if c[0] is not None:
|
||||
model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
|
||||
temp["cross_attn"] = c[0]
|
||||
temp["model_conds"] = model_conds
|
||||
out.append(temp)
|
||||
return out
|
||||
|
||||
def get_additional_models(conds, dtype):
|
||||
"""loads additional models in conditioning"""
|
||||
cnets = []
|
||||
gligen = []
|
||||
|
||||
for k in conds:
|
||||
cnets += get_models_from_cond(conds[k], "control")
|
||||
gligen += get_models_from_cond(conds[k], "gligen")
|
||||
|
||||
control_nets = set(cnets)
|
||||
|
||||
inference_memory = 0
|
||||
control_models = []
|
||||
for m in control_nets:
|
||||
control_models += m.get_models()
|
||||
inference_memory += m.inference_memory_requirements(dtype)
|
||||
|
||||
gligen = [x[1] for x in gligen]
|
||||
models = control_models + gligen
|
||||
return models, inference_memory
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
"""cleanup additional models that were loaded"""
|
||||
for m in models:
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
|
||||
def prepare_sampling(model, noise_shape, conds):
|
||||
device = model.load_device
|
||||
real_model = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
def cleanup_models(conds, models):
|
||||
cleanup_additional_models(models)
|
||||
|
||||
control_cleanup = []
|
||||
for k in conds:
|
||||
control_cleanup += get_models_from_cond(conds[k], "control")
|
||||
|
||||
cleanup_additional_models(set(control_cleanup))
|
||||
@ -5,6 +5,7 @@ import collections
|
||||
from . import model_management
|
||||
import math
|
||||
import logging
|
||||
from . import sampler_helpers
|
||||
|
||||
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
|
||||
|
||||
@ -130,30 +131,23 @@ def cond_cat(c_list):
|
||||
|
||||
return out
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
out_uncond = torch.zeros_like(x_in)
|
||||
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
COND = 0
|
||||
UNCOND = 1
|
||||
|
||||
def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
to_run = []
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, COND)]
|
||||
if uncond is not None:
|
||||
for x in uncond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
to_run += [(p, UNCOND)]
|
||||
cond = conds[i]
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, i)]
|
||||
|
||||
while len(to_run) > 0:
|
||||
first = to_run[0]
|
||||
@ -225,74 +219,66 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
del input_x
|
||||
|
||||
for o in range(batch_chunks):
|
||||
if cond_or_uncond[o] == COND:
|
||||
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||
else:
|
||||
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||
del mult
|
||||
cond_index = cond_or_uncond[o]
|
||||
out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||
out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||
|
||||
out_cond /= out_count
|
||||
del out_count
|
||||
out_uncond /= out_uncond_count
|
||||
del out_uncond_count
|
||||
return out_cond, out_uncond
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
|
||||
return out_conds
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||
|
||||
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
||||
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
|
||||
cfg_result = x - model_options["sampler_cfg_function"](args)
|
||||
else:
|
||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||
|
||||
for fn in model_options.get("sampler_post_cfg_function", []):
|
||||
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
||||
"sigma": timestep, "model_options": model_options, "input": x}
|
||||
cfg_result = fn(args)
|
||||
|
||||
return cfg_result
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns denoised
|
||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
||||
uncond_ = None
|
||||
else:
|
||||
uncond_ = uncond
|
||||
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
||||
uncond_ = None
|
||||
else:
|
||||
uncond_ = uncond
|
||||
|
||||
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
||||
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
|
||||
cfg_result = x - model_options["sampler_cfg_function"](args)
|
||||
else:
|
||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||
conds = [cond, uncond_]
|
||||
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
|
||||
|
||||
for fn in model_options.get("sampler_post_cfg_function", []):
|
||||
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
||||
"sigma": timestep, "model_options": model_options, "input": x}
|
||||
cfg_result = fn(args)
|
||||
|
||||
return cfg_result
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
||||
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||
return out
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
class KSamplerX0Inpaint(torch.nn.Module):
|
||||
class KSamplerX0Inpaint:
|
||||
def __init__(self, model, sigmas):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.sigmas = sigmas
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
|
||||
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
|
||||
if denoise_mask is not None:
|
||||
if "denoise_mask_function" in model_options:
|
||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
|
||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
|
||||
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
||||
if denoise_mask is not None:
|
||||
out = out * denoise_mask + self.latent_image * latent_mask
|
||||
return out
|
||||
|
||||
def simple_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
def simple_scheduler(model_sampling, steps):
|
||||
s = model_sampling
|
||||
sigs = []
|
||||
ss = len(s.sigmas) / steps
|
||||
for x in range(steps):
|
||||
@ -300,8 +286,8 @@ def simple_scheduler(model, steps):
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def ddim_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
def ddim_scheduler(model_sampling, steps):
|
||||
s = model_sampling
|
||||
sigs = []
|
||||
ss = max(len(s.sigmas) // steps, 1)
|
||||
x = 1
|
||||
@ -312,8 +298,8 @@ def ddim_scheduler(model, steps):
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def normal_scheduler(model, steps, sgm=False, floor=False):
|
||||
s = model.model_sampling
|
||||
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
||||
s = model_sampling
|
||||
start = s.timestep(s.sigma_max)
|
||||
end = s.timestep(s.sigma_min)
|
||||
|
||||
@ -574,59 +560,120 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||
|
||||
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
||||
|
||||
def wrap_model(model):
|
||||
model_denoise = CFGNoisePredictor(model)
|
||||
return model_denoise
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
positive = positive[:]
|
||||
negative = negative[:]
|
||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
|
||||
for k in conds:
|
||||
conds[k] = conds[k][:]
|
||||
resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device)
|
||||
|
||||
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
|
||||
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
|
||||
|
||||
model_wrap = wrap_model(model)
|
||||
|
||||
calculate_start_end_timesteps(model, negative)
|
||||
calculate_start_end_timesteps(model, positive)
|
||||
|
||||
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
||||
latent_image = model.process_latent_in(latent_image)
|
||||
for k in conds:
|
||||
calculate_start_end_timesteps(model, conds[k])
|
||||
|
||||
if hasattr(model, 'extra_conds'):
|
||||
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
for k in conds:
|
||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for c in positive:
|
||||
create_cond_with_same_area_if_none(negative, c)
|
||||
for c in negative:
|
||||
create_cond_with_same_area_if_none(positive, c)
|
||||
for k in conds:
|
||||
for c in conds[k]:
|
||||
for kk in conds:
|
||||
if k != kk:
|
||||
create_cond_with_same_area_if_none(conds[kk], c)
|
||||
|
||||
pre_run_control(model, negative + positive)
|
||||
for k in conds:
|
||||
pre_run_control(model, conds[k])
|
||||
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
if "positive" in conds:
|
||||
positive = conds["positive"]
|
||||
for k in conds:
|
||||
if k != "positive":
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||
return conds
|
||||
|
||||
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return model.process_latent_out(samples.to(torch.float32))
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher):
|
||||
self.model_patcher = model_patcher
|
||||
self.model_options = model_patcher.model_options
|
||||
self.original_conds = {}
|
||||
self.cfg = 1.0
|
||||
|
||||
def set_conds(self, positive, negative):
|
||||
self.inner_set_conds({"positive": positive, "negative": negative})
|
||||
|
||||
def set_cfg(self, cfg):
|
||||
self.cfg = cfg
|
||||
|
||||
def inner_set_conds(self, conds):
|
||||
for k in conds:
|
||||
self.original_conds[k] = sampler_helpers.convert_cond(conds[k])
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.predict_noise(*args, **kwargs)
|
||||
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
||||
|
||||
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
|
||||
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
||||
latent_image = self.inner_model.process_latent_in(latent_image)
|
||||
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||
|
||||
extra_args = {"model_options": self.model_options, "seed":seed}
|
||||
|
||||
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
if sigmas.shape[-1] == 0:
|
||||
return latent_image
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
|
||||
self.inner_model, self.conds, self.loaded_models = sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
if denoise_mask is not None:
|
||||
denoise_mask = sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
||||
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
del self.conds
|
||||
del self.loaded_models
|
||||
return output
|
||||
|
||||
|
||||
def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
cfg_guider = CFGGuider(model)
|
||||
cfg_guider.set_conds(positive, negative)
|
||||
cfg_guider.set_cfg(cfg)
|
||||
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
|
||||
|
||||
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
if scheduler_name == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||
elif scheduler_name == "exponential":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||
elif scheduler_name == "normal":
|
||||
sigmas = normal_scheduler(model, steps)
|
||||
sigmas = normal_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "simple":
|
||||
sigmas = simple_scheduler(model, steps)
|
||||
sigmas = simple_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(model, steps)
|
||||
sigmas = ddim_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "sgm_uniform":
|
||||
sigmas = normal_scheduler(model, steps, sgm=True)
|
||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||
else:
|
||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||
return sigmas
|
||||
@ -668,7 +715,7 @@ class KSampler:
|
||||
steps += 1
|
||||
discard_penultimate_sigma = True
|
||||
|
||||
sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps)
|
||||
sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps)
|
||||
|
||||
if discard_penultimate_sigma:
|
||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||
@ -679,9 +726,12 @@ class KSampler:
|
||||
if denoise is None or denoise > 0.9999:
|
||||
self.sigmas = self.calculate_sigmas(steps).to(self.device)
|
||||
else:
|
||||
new_steps = int(steps/denoise)
|
||||
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
||||
self.sigmas = sigmas[-(steps + 1):]
|
||||
if denoise <= 0.0:
|
||||
self.sigmas = torch.FloatTensor([])
|
||||
else:
|
||||
new_steps = int(steps/denoise)
|
||||
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
||||
self.sigmas = sigmas[-(steps + 1):]
|
||||
|
||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||
if sigmas is None:
|
||||
|
||||
@ -600,7 +600,7 @@ def load_unet(unet_path):
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
|
||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||
clip_sd = None
|
||||
load_models = [model]
|
||||
if clip is not None:
|
||||
@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
||||
model_management.load_models_gpu(load_models)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||
for k in extra_keys:
|
||||
sd[k] = extra_keys[k]
|
||||
|
||||
utils.save_torch_file(sd, output_path, metadata=metadata)
|
||||
|
||||
@ -174,6 +174,11 @@ class SDXL(supported_models_base.BASE):
|
||||
self.sampling_settings["sigma_max"] = 80.0
|
||||
self.sampling_settings["sigma_min"] = 0.002
|
||||
return model_base.ModelType.EDM
|
||||
elif "edm_vpred.sigma_max" in state_dict:
|
||||
self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
|
||||
if "edm_vpred.sigma_min" in state_dict:
|
||||
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
|
||||
return model_base.ModelType.V_PREDICTION_EDM
|
||||
elif "v_pred" in state_dict:
|
||||
return model_base.ModelType.V_PREDICTION
|
||||
else:
|
||||
@ -334,6 +339,11 @@ class Stable_Zero123(supported_models_base.BASE):
|
||||
"num_head_channels": -1,
|
||||
}
|
||||
|
||||
required_keys = {
|
||||
"cc_projection.weight": None,
|
||||
"cc_projection.bias": None,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "cond_stage_model.model.visual."
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
@ -439,6 +449,33 @@ class Stable_Cascade_B(Stable_Cascade_C):
|
||||
out = model_base.StableCascade_B(self, device=device)
|
||||
return out
|
||||
|
||||
class SD15_instructpix2pix(SD15):
|
||||
unet_config = {
|
||||
"context_dim": 768,
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": False,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
"in_channels": 8,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SD15_instructpix2pix(self, device=device)
|
||||
|
||||
class SDXL_instructpix2pix(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 0, 2, 2, 10, 10],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
"in_channels": 8,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -16,6 +16,8 @@ class BASE:
|
||||
"num_head_channels": 64,
|
||||
}
|
||||
|
||||
required_keys = {}
|
||||
|
||||
clip_prefix = []
|
||||
clip_vision_prefix = None
|
||||
noise_aug_config = None
|
||||
@ -28,10 +30,14 @@ class BASE:
|
||||
manual_cast_dtype = None
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config):
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
for k in s.unet_config:
|
||||
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
||||
return False
|
||||
if state_dict is not None:
|
||||
for k in s.required_keys:
|
||||
if k not in state_dict:
|
||||
return False
|
||||
return True
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
|
||||
@ -7247,7 +7247,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
//create links
|
||||
for (var i = 0; i < clipboard_info.links.length; ++i) {
|
||||
var link_info = clipboard_info.links[i];
|
||||
var origin_node;
|
||||
var origin_node = undefined;
|
||||
var origin_node_relative_id = link_info[0];
|
||||
if (origin_node_relative_id != null) {
|
||||
origin_node = nodes[origin_node_relative_id];
|
||||
|
||||
@ -170,9 +170,12 @@ export async function importA1111(graph, parameters) {
|
||||
const opts = parameters
|
||||
.substr(p)
|
||||
.split("\n")[1]
|
||||
.split(",")
|
||||
.match(new RegExp("\\s*([^:]+:\\s*([^\"\\{].*?|\".*?\"|\\{.*?\\}))\\s*(,|$)", "g"))
|
||||
.reduce((p, n) => {
|
||||
const s = n.split(":");
|
||||
if (s[1].endsWith(',')) {
|
||||
s[1] = s[1].substr(0, s[1].length -1);
|
||||
}
|
||||
p[s[0].trim().toLowerCase()] = s[1].trim();
|
||||
return p;
|
||||
}, {});
|
||||
@ -191,6 +194,7 @@ export async function importA1111(graph, parameters) {
|
||||
const vaeLoaderNode = LiteGraph.createNode("VAELoader");
|
||||
const saveNode = LiteGraph.createNode("SaveImage");
|
||||
let hrSamplerNode = null;
|
||||
let hrSteps = null;
|
||||
|
||||
const ceil64 = (v) => Math.ceil(v / 64) * 64;
|
||||
|
||||
@ -290,6 +294,9 @@ export async function importA1111(graph, parameters) {
|
||||
model(v) {
|
||||
setWidgetValue(ckptNode, "ckpt_name", v, true);
|
||||
},
|
||||
"vae"(v) {
|
||||
setWidgetValue(vaeLoaderNode, "vae_name", v, true);
|
||||
},
|
||||
"cfg scale"(v) {
|
||||
setWidgetValue(samplerNode, "cfg", +v);
|
||||
},
|
||||
@ -316,6 +323,7 @@ export async function importA1111(graph, parameters) {
|
||||
const h = ceil64(+wxh[1]);
|
||||
const hrUp = popOpt("hires upscale");
|
||||
const hrSz = popOpt("hires resize");
|
||||
hrSteps = popOpt("hires steps");
|
||||
let hrMethod = popOpt("hires upscaler");
|
||||
|
||||
setWidgetValue(imageNode, "width", w);
|
||||
@ -398,7 +406,7 @@ export async function importA1111(graph, parameters) {
|
||||
}
|
||||
|
||||
if (hrSamplerNode) {
|
||||
setWidgetValue(hrSamplerNode, "steps", getWidget(samplerNode, "steps").value);
|
||||
setWidgetValue(hrSamplerNode, "steps", hrSteps? +hrSteps : getWidget(samplerNode, "steps").value);
|
||||
setWidgetValue(hrSamplerNode, "cfg", getWidget(samplerNode, "cfg").value);
|
||||
setWidgetValue(hrSamplerNode, "scheduler", getWidget(samplerNode, "scheduler").value);
|
||||
setWidgetValue(hrSamplerNode, "sampler_name", getWidget(samplerNode, "sampler_name").value);
|
||||
@ -415,7 +423,7 @@ export async function importA1111(graph, parameters) {
|
||||
|
||||
graph.arrange();
|
||||
|
||||
for (const opt of ["model hash", "ensd"]) {
|
||||
for (const opt of ["model hash", "ensd", "version", "vae hash", "ti hashes", "lora hashes", "hashes"]) {
|
||||
delete opts[opt];
|
||||
}
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||
from comfy.cmd import latent_preview
|
||||
import torch
|
||||
from comfy import utils
|
||||
from comfy import node_helpers
|
||||
|
||||
|
||||
class BasicScheduler:
|
||||
@ -26,10 +27,11 @@ class BasicScheduler:
|
||||
def get_sigmas(self, model, scheduler, steps, denoise):
|
||||
total_steps = steps
|
||||
if denoise < 1.0:
|
||||
if denoise <= 0.0:
|
||||
return (torch.FloatTensor([]),)
|
||||
total_steps = int(steps/denoise)
|
||||
|
||||
model_management.load_models_gpu([model])
|
||||
sigmas = samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
|
||||
sigmas = samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu()
|
||||
sigmas = sigmas[-(steps + 1):]
|
||||
return (sigmas, )
|
||||
|
||||
@ -162,6 +164,9 @@ class FlipSigmas:
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, sigmas):
|
||||
if len(sigmas) == 0:
|
||||
return (sigmas,)
|
||||
|
||||
sigmas = sigmas.flip(0)
|
||||
if sigmas[0] == 0:
|
||||
sigmas[0] = 0.0001
|
||||
@ -334,6 +339,24 @@ class SamplerDPMAdaptative:
|
||||
"s_noise":s_noise })
|
||||
return (sampler, )
|
||||
|
||||
class Noise_EmptyNoise:
|
||||
def __init__(self):
|
||||
self.seed = 0
|
||||
|
||||
def generate_noise(self, input_latent):
|
||||
latent_image = input_latent["samples"]
|
||||
return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
|
||||
|
||||
class Noise_RandomNoise:
|
||||
def __init__(self, seed):
|
||||
self.seed = seed
|
||||
|
||||
def generate_noise(self, input_latent):
|
||||
latent_image = input_latent["samples"]
|
||||
batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None
|
||||
return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds)
|
||||
|
||||
class SamplerCustom:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -361,10 +384,9 @@ class SamplerCustom:
|
||||
latent = latent_image
|
||||
latent_image = latent["samples"]
|
||||
if not add_noise:
|
||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
noise = Noise_EmptyNoise().generate_noise(latent)
|
||||
else:
|
||||
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
||||
noise = sample.prepare_noise(latent_image, noise_seed, batch_inds)
|
||||
noise = Noise_RandomNoise(noise_seed).generate_noise(latent)
|
||||
|
||||
noise_mask = None
|
||||
if "noise_mask" in latent:
|
||||
@ -385,6 +407,161 @@ class SamplerCustom:
|
||||
out_denoised = out
|
||||
return (out, out_denoised)
|
||||
|
||||
class Guider_Basic(comfy.samplers.CFGGuider):
|
||||
def set_conds(self, positive):
|
||||
self.inner_set_conds({"positive": positive})
|
||||
|
||||
class BasicGuider:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"conditioning": ("CONDITIONING", ),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("GUIDER",)
|
||||
|
||||
FUNCTION = "get_guider"
|
||||
CATEGORY = "sampling/custom_sampling/guiders"
|
||||
|
||||
def get_guider(self, model, conditioning):
|
||||
guider = Guider_Basic(model)
|
||||
guider.set_conds(conditioning)
|
||||
return (guider,)
|
||||
|
||||
class CFGGuider:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("GUIDER",)
|
||||
|
||||
FUNCTION = "get_guider"
|
||||
CATEGORY = "sampling/custom_sampling/guiders"
|
||||
|
||||
def get_guider(self, model, positive, negative, cfg):
|
||||
guider = comfy.samplers.CFGGuider(model)
|
||||
guider.set_conds(positive, negative)
|
||||
guider.set_cfg(cfg)
|
||||
return (guider,)
|
||||
|
||||
class Guider_DualCFG(comfy.samplers.CFGGuider):
|
||||
def set_cfg(self, cfg1, cfg2):
|
||||
self.cfg1 = cfg1
|
||||
self.cfg2 = cfg2
|
||||
|
||||
def set_conds(self, positive, middle, negative):
|
||||
middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
|
||||
self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative})
|
||||
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
negative_cond = self.conds.get("negative", None)
|
||||
middle_cond = self.conds.get("middle", None)
|
||||
|
||||
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options)
|
||||
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
|
||||
|
||||
class DualCFGGuider:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"cond1": ("CONDITIONING", ),
|
||||
"cond2": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("GUIDER",)
|
||||
|
||||
FUNCTION = "get_guider"
|
||||
CATEGORY = "sampling/custom_sampling/guiders"
|
||||
|
||||
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
|
||||
guider = Guider_DualCFG(model)
|
||||
guider.set_conds(cond1, cond2, negative)
|
||||
guider.set_cfg(cfg_conds, cfg_cond2_negative)
|
||||
return (guider,)
|
||||
|
||||
class DisableNoise:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":{
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("NOISE",)
|
||||
FUNCTION = "get_noise"
|
||||
CATEGORY = "sampling/custom_sampling/noise"
|
||||
|
||||
def get_noise(self):
|
||||
return (Noise_EmptyNoise(),)
|
||||
|
||||
|
||||
class RandomNoise(DisableNoise):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":{
|
||||
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
}
|
||||
}
|
||||
|
||||
def get_noise(self, noise_seed):
|
||||
return (Noise_RandomNoise(noise_seed),)
|
||||
|
||||
|
||||
class SamplerCustomAdvanced:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"noise": ("NOISE", ),
|
||||
"guider": ("GUIDER", ),
|
||||
"sampler": ("SAMPLER", ),
|
||||
"sigmas": ("SIGMAS", ),
|
||||
"latent_image": ("LATENT", ),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT","LATENT")
|
||||
RETURN_NAMES = ("output", "denoised_output")
|
||||
|
||||
FUNCTION = "sample"
|
||||
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
|
||||
def sample(self, noise, guider, sampler, sigmas, latent_image):
|
||||
latent = latent_image
|
||||
latent_image = latent["samples"]
|
||||
|
||||
noise_mask = None
|
||||
if "noise_mask" in latent:
|
||||
noise_mask = latent["noise_mask"]
|
||||
|
||||
x0_output = {}
|
||||
callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
|
||||
|
||||
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
||||
samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
|
||||
out = latent.copy()
|
||||
out["samples"] = samples
|
||||
if "x0" in x0_output:
|
||||
out_denoised = latent.copy()
|
||||
out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
|
||||
else:
|
||||
out_denoised = out
|
||||
return (out, out_denoised)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SamplerCustom": SamplerCustom,
|
||||
"BasicScheduler": BasicScheduler,
|
||||
@ -402,4 +579,11 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SamplerDPMAdaptative": SamplerDPMAdaptative,
|
||||
"SplitSigmas": SplitSigmas,
|
||||
"FlipSigmas": FlipSigmas,
|
||||
|
||||
"CFGGuider": CFGGuider,
|
||||
"DualCFGGuider": DualCFGGuider,
|
||||
"BasicGuider": BasicGuider,
|
||||
"RandomNoise": RandomNoise,
|
||||
"DisableNoise": DisableNoise,
|
||||
"SamplerCustomAdvanced": SamplerCustomAdvanced,
|
||||
}
|
||||
|
||||
45
comfy_extras/nodes/nodes_ip2p.py
Normal file
45
comfy_extras/nodes/nodes_ip2p.py
Normal file
@ -0,0 +1,45 @@
|
||||
import torch
|
||||
|
||||
class InstructPixToPixConditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"pixels": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/instructpix2pix"
|
||||
|
||||
def encode(self, positive, negative, pixels, vae):
|
||||
x = (pixels.shape[1] // 8) * 8
|
||||
y = (pixels.shape[2] // 8) * 8
|
||||
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
x_offset = (pixels.shape[1] % 8) // 2
|
||||
y_offset = (pixels.shape[2] % 8) // 2
|
||||
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
|
||||
|
||||
concat_latent = vae.encode(pixels)
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = torch.zeros_like(concat_latent)
|
||||
|
||||
out = []
|
||||
for conditioning in [positive, negative]:
|
||||
c = []
|
||||
for t in conditioning:
|
||||
d = t[1].copy()
|
||||
d["concat_latent_image"] = concat_latent
|
||||
n = [t[0], d]
|
||||
c.append(n)
|
||||
out.append(c)
|
||||
return (out[0], out[1], out_latent)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"InstructPixToPixConditioning": InstructPixToPixConditioning,
|
||||
}
|
||||
@ -1,8 +1,10 @@
|
||||
from comfy import sd, utils
|
||||
from comfy import model_base
|
||||
from comfy import model_management
|
||||
|
||||
from comfy import model_sampling
|
||||
from comfy.cmd import folder_paths
|
||||
|
||||
import torch
|
||||
import json
|
||||
import os
|
||||
|
||||
@ -188,6 +190,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
||||
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
|
||||
# "v2-inpainting"
|
||||
|
||||
extra_keys = {}
|
||||
_model_sampling = model.get_model_object("model_sampling")
|
||||
if isinstance(_model_sampling, model_sampling.ModelSamplingContinuousEDM):
|
||||
if isinstance(_model_sampling, model_sampling.V_PREDICTION):
|
||||
extra_keys["edm_vpred.sigma_max"] = torch.tensor(_model_sampling.sigma_max).float()
|
||||
extra_keys["edm_vpred.sigma_min"] = torch.tensor(_model_sampling.sigma_min).float()
|
||||
|
||||
if model.model.model_type == model_base.ModelType.EPS:
|
||||
metadata["modelspec.predict_key"] = "epsilon"
|
||||
elif model.model.model_type == model_base.ModelType.V_PREDICTION:
|
||||
@ -202,7 +211,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
|
||||
sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
|
||||
sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
|
||||
|
||||
class CheckpointSave:
|
||||
def __init__(self):
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
import torch
|
||||
from comfy import sample
|
||||
from comfy import samplers
|
||||
from comfy import sampler_helpers
|
||||
|
||||
|
||||
#TODO: This node should be removed and replaced with one that uses the new Guider/SamplerCustomAdvanced.
|
||||
class PerpNeg:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -17,7 +19,7 @@ class PerpNeg:
|
||||
|
||||
def patch(self, model, empty_conditioning, neg_scale):
|
||||
m = model.clone()
|
||||
nocond = sample.convert_cond(empty_conditioning)
|
||||
nocond = sampler_helpers.convert_cond(empty_conditioning)
|
||||
|
||||
def cfg_function(args):
|
||||
model = args["model"]
|
||||
@ -29,7 +31,7 @@ class PerpNeg:
|
||||
model_options = args["model_options"]
|
||||
nocond_processed = samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
|
||||
|
||||
(noise_pred_nocond, _) = samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
|
||||
(noise_pred_nocond,) = samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options)
|
||||
|
||||
pos = noise_pred_pos - noise_pred_nocond
|
||||
neg = noise_pred_neg - noise_pred_nocond
|
||||
|
||||
@ -150,7 +150,7 @@ class SelfAttentionGuidance:
|
||||
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
|
||||
degraded_noised = degraded + x - uncond_pred
|
||||
# call into the UNet
|
||||
(sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
|
||||
(sag,) = samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
|
||||
return cfg_result + (degraded - sag) * sag_scale
|
||||
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user