mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-30 00:00:26 +08:00
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in:
commit
034ffcea03
@ -139,11 +139,12 @@ class ControlBase:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
if control_model is not None:
|
||||||
|
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
self.model_sampling_current = None
|
self.model_sampling_current = None
|
||||||
self.manual_cast_dtype = manual_cast_dtype
|
self.manual_cast_dtype = manual_cast_dtype
|
||||||
@ -184,7 +185,9 @@ class ControlNet(ControlBase):
|
|||||||
return self.control_merge(None, control, control_prev, output_dtype)
|
return self.control_merge(None, control, control_prev, output_dtype)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||||
|
c.control_model = self.control_model
|
||||||
|
c.control_model_wrapped = self.control_model_wrapped
|
||||||
self.copy_to(c)
|
self.copy_to(c)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|||||||
@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
|||||||
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
||||||
code2idx = {"q": 0, "k": 1, "v": 2}
|
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||||
|
|
||||||
|
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
|
||||||
|
def cat_tensors(tensors):
    """Concatenate tensors along dim 0 without calling ``torch.cat``.

    This exists because, at the time of writing, ``torch.cat`` can't do fp8
    with cuda. The output buffer is allocated with the device, dtype and
    trailing dimensions of the first tensor, then every input is copied into
    its own slice of rows.

    Args:
        tensors: non-empty sequence of tensors. Each one is assigned into a
            slice of a buffer shaped after the first tensor's trailing
            dimensions.

    Returns:
        A newly allocated tensor holding the inputs end to end along dim 0.

    Raises:
        ValueError: if ``tensors`` is empty (previously an opaque IndexError
            from ``tensors[0]``).
    """
    if len(tensors) == 0:
        raise ValueError("cat_tensors expects a non-empty sequence of tensors")

    first = tensors[0]
    total_rows = sum(t.shape[0] for t in tensors)
    out = torch.empty([total_rows] + list(first.shape)[1:], device=first.device, dtype=first.dtype)

    # Slice assignment is used instead of torch.cat so that dtypes torch.cat
    # cannot handle can still be joined.
    offset = 0
    for t in tensors:
        rows = t.shape[0]
        out[offset:offset + rows] = t
        offset += rows

    return out
|
||||||
|
|
||||||
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
|||||||
if None in tensors:
|
if None in tensors:
|
||||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||||
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
|
new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)
|
||||||
|
|
||||||
for k_pre, tensors in capture_qkv_bias.items():
|
for k_pre, tensors in capture_qkv_bias.items():
|
||||||
if None in tensors:
|
if None in tensors:
|
||||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||||
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
|
new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)
|
||||||
|
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from .ldm.cascade.stage_c import StageC
|
|||||||
from .ldm.cascade.stage_b import StageB
|
from .ldm.cascade.stage_b import StageB
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from . import utils
|
from . import utils
|
||||||
|
from . import latent_formats
|
||||||
|
|
||||||
class ModelType(Enum):
|
class ModelType(Enum):
|
||||||
EPS = 1
|
EPS = 1
|
||||||
@ -66,7 +67,8 @@ class BaseModel(torch.nn.Module):
|
|||||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||||
if self.adm_channels is None:
|
if self.adm_channels is None:
|
||||||
self.adm_channels = 0
|
self.adm_channels = 0
|
||||||
self.inpaint_model = False
|
|
||||||
|
self.concat_keys = ()
|
||||||
logging.info("model_type {}".format(model_type.name))
|
logging.info("model_type {}".format(model_type.name))
|
||||||
logging.debug("adm {}".format(self.adm_channels))
|
logging.debug("adm {}".format(self.adm_channels))
|
||||||
|
|
||||||
@ -107,8 +109,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
if self.inpaint_model:
|
if len(self.concat_keys) > 0:
|
||||||
concat_keys = ("mask", "masked_image")
|
|
||||||
cond_concat = []
|
cond_concat = []
|
||||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||||
concat_latent_image = kwargs.get("concat_latent_image", None)
|
concat_latent_image = kwargs.get("concat_latent_image", None)
|
||||||
@ -125,24 +126,16 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
|
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
|
||||||
|
|
||||||
if len(denoise_mask.shape) == len(noise.shape):
|
if denoise_mask is not None:
|
||||||
denoise_mask = denoise_mask[:,:1]
|
if len(denoise_mask.shape) == len(noise.shape):
|
||||||
|
denoise_mask = denoise_mask[:,:1]
|
||||||
|
|
||||||
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
|
denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
|
||||||
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
if denoise_mask.shape[-2:] != noise.shape[-2:]:
|
||||||
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||||
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
|
||||||
|
|
||||||
def blank_inpaint_image_like(latent_image):
|
for ck in self.concat_keys:
|
||||||
blank_image = torch.ones_like(latent_image)
|
|
||||||
# these are the values for "zero" in pixel space translated to latent space
|
|
||||||
blank_image[:,0] *= 0.8223
|
|
||||||
blank_image[:,1] *= -0.6876
|
|
||||||
blank_image[:,2] *= 0.6364
|
|
||||||
blank_image[:,3] *= 0.1380
|
|
||||||
return blank_image
|
|
||||||
|
|
||||||
for ck in concat_keys:
|
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
if ck == "mask":
|
if ck == "mask":
|
||||||
cond_concat.append(denoise_mask.to(device))
|
cond_concat.append(denoise_mask.to(device))
|
||||||
@ -152,7 +145,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
if ck == "mask":
|
if ck == "mask":
|
||||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||||
elif ck == "masked_image":
|
elif ck == "masked_image":
|
||||||
cond_concat.append(blank_inpaint_image_like(noise))
|
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||||
data = torch.cat(cond_concat, dim=1)
|
data = torch.cat(cond_concat, dim=1)
|
||||||
out['c_concat'] = conds.CONDNoiseShape(data)
|
out['c_concat'] = conds.CONDNoiseShape(data)
|
||||||
adm = self.encode_adm(**kwargs)
|
adm = self.encode_adm(**kwargs)
|
||||||
@ -220,7 +213,16 @@ class BaseModel(torch.nn.Module):
|
|||||||
return unet_state_dict
|
return unet_state_dict
|
||||||
|
|
||||||
def set_inpaint(self):
|
def set_inpaint(self):
|
||||||
self.inpaint_model = True
|
self.concat_keys = ("mask", "masked_image")
|
||||||
|
def blank_inpaint_image_like(latent_image):
|
||||||
|
blank_image = torch.ones_like(latent_image)
|
||||||
|
# these are the values for "zero" in pixel space translated to latent space
|
||||||
|
blank_image[:,0] *= 0.8223
|
||||||
|
blank_image[:,1] *= -0.6876
|
||||||
|
blank_image[:,2] *= 0.6364
|
||||||
|
blank_image[:,3] *= 0.1380
|
||||||
|
return blank_image
|
||||||
|
self.blank_inpaint_image_like = blank_inpaint_image_like
|
||||||
|
|
||||||
def memory_required(self, input_shape):
|
def memory_required(self, input_shape):
|
||||||
if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
|
if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
|
||||||
@ -471,6 +473,42 @@ class SD_X4Upscaler(BaseModel):
|
|||||||
out['y'] = conds.CONDRegular(noise_level)
|
out['y'] = conds.CONDRegular(noise_level)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class IP2P:
    """Mixin providing ``extra_conds`` for instruct-pix2pix style models.

    The class it is combined with must supply ``encode_adm`` and set
    ``self.process_ip2p_image_in`` (a callable applied to the image latent
    before it is stored as the concat conditioning).
    """

    def extra_conds(self, **kwargs):
        """Build the extra conditioning dict from sampling kwargs.

        Reads ``concat_latent_image`` and ``noise`` (both optional lookups)
        and ``device`` (required) from ``kwargs``.

        Returns:
            dict with ``'c_concat'`` always set, and ``'y'`` set only when
            ``self.encode_adm(**kwargs)`` returns something other than None.
        """
        latent = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        # No reference image supplied: fall back to zeros shaped like noise.
        if latent is None:
            latent = torch.zeros_like(noise)

        # Everything past the batch dimension has to line up with the noise.
        if latent.shape[1:] != noise.shape[1:]:
            target_width = noise.shape[-1]
            target_height = noise.shape[-2]
            latent = utils.common_upscale(latent.to(device), target_width, target_height, "bilinear", "center")

        latent = utils.resize_to_batch_size(latent, noise.shape[0])

        result = {'c_concat': conds.CONDNoiseShape(self.process_ip2p_image_in(latent))}

        adm = self.encode_adm(**kwargs)
        if adm is not None:
            result['y'] = conds.CONDRegular(adm)
        return result
|
||||||
|
|
||||||
|
class SD15_instructpix2pix(IP2P, BaseModel):
    """``BaseModel`` combined with the ``IP2P`` conditioning mixin.

    The image latent is passed through unchanged before being used as the
    concat conditioning.
    """

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        # Identity: no latent pre-processing is applied.
        self.process_ip2p_image_in = lambda image: image
|
||||||
|
|
||||||
|
class SDXL_instructpix2pix(IP2P, SDXL):
    """``SDXL`` combined with the ``IP2P`` conditioning mixin.

    How the image latent is pre-processed depends on ``model_type``: for
    ``ModelType.V_PREDICTION_EDM`` it goes through
    ``latent_formats.SDXL().process_in``; for anything else it is used as is.
    """

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        if model_type != ModelType.V_PREDICTION_EDM:
            # diffusers ip2p
            self.process_ip2p_image_in = lambda image: image
        else:
            # cosxl ip2p
            self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image)
|
||||||
|
|
||||||
|
|
||||||
class StableCascade_C(BaseModel):
|
class StableCascade_C(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
||||||
|
|||||||
@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
|
|
||||||
return unet_config
|
return unet_config
|
||||||
|
|
||||||
def model_config_from_unet_config(unet_config):
|
def model_config_from_unet_config(unet_config, state_dict=None):
|
||||||
for model_config in supported_models.models:
|
for model_config in supported_models.models:
|
||||||
if model_config.matches(unet_config):
|
if model_config.matches(unet_config, state_dict):
|
||||||
return model_config(unet_config)
|
return model_config(unet_config)
|
||||||
|
|
||||||
logging.error("no match {}".format(unet_config))
|
logging.error("no match {}".format(unet_config))
|
||||||
@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config):
|
|||||||
|
|
||||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||||
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||||
model_config = model_config_from_unet_config(unet_config)
|
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||||
if model_config is None and use_base_if_no_match:
|
if model_config is None and use_base_if_no_match:
|
||||||
return supported_models_base.BASE(unet_config)
|
return supported_models_base.BASE(unet_config)
|
||||||
else:
|
else:
|
||||||
@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
|
SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
|
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
|
||||||
|
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||||
|
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||||
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
|
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
|
||||||
@ -351,7 +357,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||||||
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
|
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
|
||||||
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
|
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
|
||||||
|
|
||||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS]
|
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SDXL_diffusers_ip2p]
|
||||||
|
|
||||||
for unet_config in supported_models:
|
for unet_config in supported_models:
|
||||||
matches = True
|
matches = True
|
||||||
|
|||||||
@ -396,6 +396,7 @@ def load_models_gpu(models, memory_required=0):
|
|||||||
inference_memory = minimum_inference_memory()
|
inference_memory = minimum_inference_memory()
|
||||||
extra_mem = max(inference_memory, memory_required)
|
extra_mem = max(inference_memory, memory_required)
|
||||||
|
|
||||||
|
models = set(models)
|
||||||
models_to_load = []
|
models_to_load = []
|
||||||
models_already_loaded = []
|
models_already_loaded = []
|
||||||
for x in models:
|
for x in models:
|
||||||
|
|||||||
@ -150,6 +150,15 @@ class ModelPatcher:
|
|||||||
def add_object_patch(self, name, obj):
|
def add_object_patch(self, name, obj):
|
||||||
self.object_patches[name] = obj
|
self.object_patches[name] = obj
|
||||||
|
|
||||||
|
def get_model_object(self, name):
    """Look up the object currently associated with ``name``.

    Lookup order: ``self.object_patches``, then
    ``self.object_patches_backup``, and finally the attribute read from
    ``self.model`` through ``utils.get_attr``.

    Args:
        name: key used in the patch dicts and passed to ``utils.get_attr``.

    Returns:
        The first match found in the order above.
    """
    patches = self.object_patches
    if name in patches:
        return patches[name]

    backups = self.object_patches_backup
    if name in backups:
        return backups[name]

    return utils.get_attr(self.model, name)
|
||||||
|
|
||||||
def model_patches_to(self, device):
|
def model_patches_to(self, device):
|
||||||
to = self.model_options["transformer_options"]
|
to = self.model_options["transformer_options"]
|
||||||
if "patches" in to:
|
if "patches" in to:
|
||||||
@ -278,7 +287,7 @@ class ModelPatcher:
|
|||||||
if weight_key in self.patches:
|
if weight_key in self.patches:
|
||||||
m.weight_function = LowVramPatch(weight_key, self)
|
m.weight_function = LowVramPatch(weight_key, self)
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
m.bias_function = LowVramPatch(weight_key, self)
|
m.bias_function = LowVramPatch(bias_key, self)
|
||||||
|
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
@ -462,4 +471,4 @@ class ModelPatcher:
|
|||||||
for k in keys:
|
for k in keys:
|
||||||
utils.set_attr(self.model, k, self.object_patches_backup[k])
|
utils.set_attr(self.model, k, self.object_patches_backup[k])
|
||||||
|
|
||||||
self.object_patches_backup = {}
|
self.object_patches_backup.clear()
|
||||||
|
|||||||
10
comfy/node_helpers.py
Normal file
10
comfy/node_helpers.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
def conditioning_set_values(conditioning, values=None):
    """Return a new conditioning list with ``values`` written into each entry.

    Every entry ``t`` of ``conditioning`` is rebuilt as
    ``[t[0], t[1].copy()]`` and each key of ``values`` is then assigned on
    the copy, so the input entries and their dicts are not modified.

    Args:
        conditioning: iterable of entries indexable as ``t[0]`` and ``t[1]``,
            where ``t[1]`` supports ``.copy()`` and item assignment.
        values: mapping of keys to set on every entry; ``None`` (the default)
            means no keys are set.

    Returns:
        A new list of ``[t[0], updated_copy_of_t[1]]`` pairs.
    """
    # None sentinel instead of a shared mutable ``{}`` default.
    if values is None:
        values = {}

    updated = []
    for t in conditioning:
        options = t[1].copy()
        for key in values:
            options[key] = values[key]
        updated.append([t[0], options])

    return updated
|
||||||
@ -30,7 +30,7 @@ from ..model_downloader import get_filename_list_with_downloadable, get_or_downl
|
|||||||
from ..nodes.common import MAX_RESOLUTION
|
from ..nodes.common import MAX_RESOLUTION
|
||||||
from .. import controlnet
|
from .. import controlnet
|
||||||
from ..open_exr import load_exr
|
from ..open_exr import load_exr
|
||||||
|
from .. import node_helpers
|
||||||
|
|
||||||
class CLIPTextEncode:
|
class CLIPTextEncode:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -140,13 +140,9 @@ class ConditioningSetArea:
|
|||||||
CATEGORY = "conditioning"
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
def append(self, conditioning, width, height, x, y, strength):
|
def append(self, conditioning, width, height, x, y, strength):
|
||||||
c = []
|
c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8),
|
||||||
for t in conditioning:
|
"strength": strength,
|
||||||
n = [t[0], t[1].copy()]
|
"set_area_to_bounds": False})
|
||||||
n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
|
|
||||||
n[1]['strength'] = strength
|
|
||||||
n[1]['set_area_to_bounds'] = False
|
|
||||||
c.append(n)
|
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
class ConditioningSetAreaPercentage:
|
class ConditioningSetAreaPercentage:
|
||||||
@ -165,13 +161,9 @@ class ConditioningSetAreaPercentage:
|
|||||||
CATEGORY = "conditioning"
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
def append(self, conditioning, width, height, x, y, strength):
|
def append(self, conditioning, width, height, x, y, strength):
|
||||||
c = []
|
c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x),
|
||||||
for t in conditioning:
|
"strength": strength,
|
||||||
n = [t[0], t[1].copy()]
|
"set_area_to_bounds": False})
|
||||||
n[1]['area'] = ("percentage", height, width, y, x)
|
|
||||||
n[1]['strength'] = strength
|
|
||||||
n[1]['set_area_to_bounds'] = False
|
|
||||||
c.append(n)
|
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
class ConditioningSetAreaStrength:
|
class ConditioningSetAreaStrength:
|
||||||
@ -186,11 +178,7 @@ class ConditioningSetAreaStrength:
|
|||||||
CATEGORY = "conditioning"
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
def append(self, conditioning, strength):
|
def append(self, conditioning, strength):
|
||||||
c = []
|
c = node_helpers.conditioning_set_values(conditioning, {"strength": strength})
|
||||||
for t in conditioning:
|
|
||||||
n = [t[0], t[1].copy()]
|
|
||||||
n[1]['strength'] = strength
|
|
||||||
c.append(n)
|
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
|
|
||||||
@ -208,19 +196,15 @@ class ConditioningSetMask:
|
|||||||
CATEGORY = "conditioning"
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
def append(self, conditioning, mask, set_cond_area, strength):
|
def append(self, conditioning, mask, set_cond_area, strength):
|
||||||
c = []
|
|
||||||
set_area_to_bounds = False
|
set_area_to_bounds = False
|
||||||
if set_cond_area != "default":
|
if set_cond_area != "default":
|
||||||
set_area_to_bounds = True
|
set_area_to_bounds = True
|
||||||
if len(mask.shape) < 3:
|
if len(mask.shape) < 3:
|
||||||
mask = mask.unsqueeze(0)
|
mask = mask.unsqueeze(0)
|
||||||
for t in conditioning:
|
|
||||||
n = [t[0], t[1].copy()]
|
c = node_helpers.conditioning_set_values(conditioning, {"mask": mask,
|
||||||
_, h, w = mask.shape
|
"set_area_to_bounds": set_area_to_bounds,
|
||||||
n[1]['mask'] = mask
|
"mask_strength": strength})
|
||||||
n[1]['set_area_to_bounds'] = set_area_to_bounds
|
|
||||||
n[1]['mask_strength'] = strength
|
|
||||||
c.append(n)
|
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
class ConditioningZeroOut:
|
class ConditioningZeroOut:
|
||||||
@ -255,13 +239,8 @@ class ConditioningSetTimestepRange:
|
|||||||
CATEGORY = "advanced/conditioning"
|
CATEGORY = "advanced/conditioning"
|
||||||
|
|
||||||
def set_range(self, conditioning, start, end):
|
def set_range(self, conditioning, start, end):
|
||||||
c = []
|
c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start,
|
||||||
for t in conditioning:
|
"end_percent": end})
|
||||||
d = t[1].copy()
|
|
||||||
d['start_percent'] = start
|
|
||||||
d['end_percent'] = end
|
|
||||||
n = [t[0], d]
|
|
||||||
c.append(n)
|
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
class VAEDecode:
|
class VAEDecode:
|
||||||
@ -402,13 +381,8 @@ class InpaintModelConditioning:
|
|||||||
|
|
||||||
out = []
|
out = []
|
||||||
for conditioning in [positive, negative]:
|
for conditioning in [positive, negative]:
|
||||||
c = []
|
c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
|
||||||
for t in conditioning:
|
"concat_mask": mask})
|
||||||
d = t[1].copy()
|
|
||||||
d["concat_latent_image"] = concat_latent
|
|
||||||
d["concat_mask"] = mask
|
|
||||||
n = [t[0], d]
|
|
||||||
c.append(n)
|
|
||||||
out.append(c)
|
out.append(c)
|
||||||
return (out[0], out[1], out_latent)
|
return (out[0], out[1], out_latent)
|
||||||
|
|
||||||
|
|||||||
@ -8,14 +8,14 @@ import sys
|
|||||||
import time
|
import time
|
||||||
import types
|
import types
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Iterable
|
||||||
from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
|
from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
|
||||||
|
|
||||||
from . import base_nodes
|
from . import base_nodes
|
||||||
from .package_typing import ExportedNodes
|
from .package_typing import ExportedNodes
|
||||||
|
|
||||||
|
|
||||||
def _vanilla_load_importing_execute_prestartup_script(node_paths: List[str]) -> None:
|
def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str]) -> None:
|
||||||
def execute_script(script_path):
|
def execute_script(script_path):
|
||||||
module_name = splitext(script_path)[0]
|
module_name = splitext(script_path)[0]
|
||||||
try:
|
try:
|
||||||
@ -121,7 +121,7 @@ def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes:
|
|||||||
return exported_nodes
|
return exported_nodes
|
||||||
|
|
||||||
|
|
||||||
def _vanilla_load_custom_nodes_2(node_paths: List[str]) -> ExportedNodes:
|
def _vanilla_load_custom_nodes_2(node_paths: Iterable[str]) -> ExportedNodes:
|
||||||
base_node_names = set(base_nodes.NODE_CLASS_MAPPINGS.keys())
|
base_node_names = set(base_nodes.NODE_CLASS_MAPPINGS.keys())
|
||||||
node_import_times = []
|
node_import_times = []
|
||||||
exported_nodes = ExportedNodes()
|
exported_nodes = ExportedNodes()
|
||||||
@ -192,6 +192,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
|
|||||||
if is_git_repository:
|
if is_git_repository:
|
||||||
node_paths += [abspath(join(potential_git_dir_parent, "custom_nodes"))]
|
node_paths += [abspath(join(potential_git_dir_parent, "custom_nodes"))]
|
||||||
|
|
||||||
|
node_paths = frozenset(abspath(custom_node_path) for custom_node_path in node_paths)
|
||||||
|
|
||||||
_vanilla_load_importing_execute_prestartup_script(node_paths)
|
_vanilla_load_importing_execute_prestartup_script(node_paths)
|
||||||
vanilla_custom_nodes = _vanilla_load_custom_nodes_2(node_paths)
|
vanilla_custom_nodes = _vanilla_load_custom_nodes_2(node_paths)
|
||||||
return vanilla_custom_nodes
|
return vanilla_custom_nodes
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from . import utils
|
|||||||
from . import conds
|
from . import conds
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
|
||||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||||
"""
|
"""
|
||||||
@ -25,94 +26,21 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
|||||||
noises = torch.cat(noises, axis=0)
|
noises = torch.cat(noises, axis=0)
|
||||||
return noises
|
return noises
|
||||||
|
|
||||||
def prepare_mask(noise_mask, shape, device):
|
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
||||||
"""ensures noise mask is of proper dimensions"""
|
logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
|
||||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
return model, positive, negative, noise_mask, []
|
||||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
|
||||||
noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0])
|
|
||||||
noise_mask = noise_mask.to(device)
|
|
||||||
return noise_mask
|
|
||||||
|
|
||||||
def get_models_from_cond(cond, model_type):
|
|
||||||
models = []
|
|
||||||
for c in cond:
|
|
||||||
if model_type in c:
|
|
||||||
models += [c[model_type]]
|
|
||||||
return models
|
|
||||||
|
|
||||||
def convert_cond(cond):
|
|
||||||
out = []
|
|
||||||
for c in cond:
|
|
||||||
temp = c[1].copy()
|
|
||||||
model_conds = temp.get("model_conds", {})
|
|
||||||
if c[0] is not None:
|
|
||||||
model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
|
|
||||||
temp["cross_attn"] = c[0]
|
|
||||||
temp["model_conds"] = model_conds
|
|
||||||
out.append(temp)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def get_additional_models(positive, negative, dtype):
|
|
||||||
"""loads additional models in positive and negative conditioning"""
|
|
||||||
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
|
|
||||||
|
|
||||||
inference_memory = 0
|
|
||||||
control_models = []
|
|
||||||
for m in control_nets:
|
|
||||||
control_models += m.get_models()
|
|
||||||
inference_memory += m.inference_memory_requirements(dtype)
|
|
||||||
|
|
||||||
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
|
|
||||||
gligen = [x[1] for x in gligen]
|
|
||||||
models = control_models + gligen
|
|
||||||
return models, inference_memory
|
|
||||||
|
|
||||||
def cleanup_additional_models(models):
|
def cleanup_additional_models(models):
|
||||||
"""cleanup additional models that were loaded"""
|
logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")
|
||||||
for m in models:
|
|
||||||
if hasattr(m, 'cleanup'):
|
|
||||||
m.cleanup()
|
|
||||||
|
|
||||||
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
|
||||||
device = model.load_device
|
|
||||||
positive = convert_cond(positive)
|
|
||||||
negative = convert_cond(negative)
|
|
||||||
|
|
||||||
if noise_mask is not None:
|
|
||||||
noise_mask = prepare_mask(noise_mask, noise_shape, device)
|
|
||||||
|
|
||||||
real_model = None
|
|
||||||
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
|
|
||||||
model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
|
|
||||||
real_model = model.model
|
|
||||||
|
|
||||||
return real_model, positive, negative, noise_mask, models
|
|
||||||
|
|
||||||
|
|
||||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||||
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
|
sampler = samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||||
|
|
||||||
noise = noise.to(model.load_device)
|
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||||
latent_image = latent_image.to(model.load_device)
|
|
||||||
|
|
||||||
sampler = samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
|
||||||
|
|
||||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
|
||||||
samples = samples.to(model_management.intermediate_device())
|
samples = samples.to(model_management.intermediate_device())
|
||||||
|
|
||||||
cleanup_additional_models(models)
|
|
||||||
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
|
samples = samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||||
noise = noise.to(model.load_device)
|
|
||||||
latent_image = latent_image.to(model.load_device)
|
|
||||||
sigmas = sigmas.to(model.load_device)
|
|
||||||
|
|
||||||
samples = samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
|
||||||
samples = samples.to(model_management.intermediate_device())
|
samples = samples.to(model_management.intermediate_device())
|
||||||
cleanup_additional_models(models)
|
|
||||||
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|||||||
77
comfy/sampler_helpers.py
Normal file
77
comfy/sampler_helpers.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
import torch
|
||||||
|
from comfy import model_management
|
||||||
|
from comfy import utils
|
||||||
|
from comfy import conds
|
||||||
|
|
||||||
|
def prepare_mask(noise_mask, shape, device):
    """Resize, channel-tile and batch-repeat a mask to fit a latent shape.

    The mask is viewed as (-1, 1, H, W) using its last two dimensions,
    bilinearly interpolated to (shape[2], shape[3]), concatenated shape[1]
    times along dim 1, passed through utils.repeat_to_batch_size with
    shape[0], and finally moved to device.
    """
    mask_h, mask_w = noise_mask.shape[-2], noise_mask.shape[-1]
    single_channel = noise_mask.reshape((-1, 1, mask_h, mask_w))
    resized = torch.nn.functional.interpolate(single_channel, size=(shape[2], shape[3]), mode="bilinear")
    tiled = torch.cat([resized] * shape[1], dim=1)
    batched = utils.repeat_to_batch_size(tiled, shape[0])
    return batched.to(device)
|
||||||
|
|
||||||
|
def get_models_from_cond(cond, model_type):
    """Collect the values stored under model_type in the entries of cond.

    Entries that do not contain model_type are skipped; the order of the
    remaining values follows the order of cond.
    """
    return [entry[model_type] for entry in cond if model_type in entry]
|
||||||
|
|
||||||
|
def convert_cond(cond):
    """Convert (cross_attn, options) conditioning pairs into flat dicts.

    For every entry c of cond, a copy of the options dict c[1] is produced.
    When c[0] is not None it is stored under "cross_attn" and wrapped in a
    conds.CONDCrossAttn under model_conds["c_crossattn"]. The result always
    carries a "model_conds" key.

    The caller's dicts are never modified: both the options dict and its
    nested "model_conds" dict are copied before anything is written.

    Returns:
        A new list with one dict per entry of cond.
    """
    out = []
    for c in cond:
        temp = c[1].copy()
        # temp is only a shallow copy, so the nested dict would still be the
        # caller's object; copy it too so "c_crossattn" is not written into
        # the input conditioning.
        model_conds = dict(temp.get("model_conds", {}))
        if c[0] is not None:
            model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
            temp["cross_attn"] = c[0]
        temp["model_conds"] = model_conds
        out.append(temp)
    return out
|
||||||
|
|
||||||
|
def get_additional_models(conds, dtype):
    """Gather the extra models referenced by the conditioning.

    Every list in conds is scanned for "control" and "gligen" entries.
    Duplicate control objects are collapsed with a set; each remaining one
    contributes its get_models() result and its
    inference_memory_requirements(dtype). For gligen entries, element [1]
    of each entry is taken as the model.

    Returns:
        (models, inference_memory): control models followed by gligen
        models, and the summed memory requirement of the control objects.
    """
    control_entries = []
    gligen_entries = []
    for name in conds:
        cond_list = conds[name]
        control_entries.extend(get_models_from_cond(cond_list, "control"))
        gligen_entries.extend(get_models_from_cond(cond_list, "gligen"))

    inference_memory = 0
    control_models = []
    for cnet in set(control_entries):
        control_models.extend(cnet.get_models())
        inference_memory += cnet.inference_memory_requirements(dtype)

    gligen_models = [entry[1] for entry in gligen_entries]
    return control_models + gligen_models, inference_memory
|
||||||
|
|
||||||
|
def cleanup_additional_models(models):
    """Call cleanup() on every item of models that has such an attribute."""
    # The generator is lazy, so each item is checked right before its own
    # cleanup() call, exactly as a check-then-call loop would do.
    with_cleanup = (item for item in models if hasattr(item, 'cleanup'))
    for item in with_cleanup:
        item.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_sampling(model, noise_shape, conds):
    """Load the model and the models referenced by conds for sampling.

    The additional models and their memory requirement come from
    get_additional_models. Everything is then handed to
    model_management.load_models_gpu together with a memory estimate.

    Returns:
        (real_model, conds, models): model.model read after loading, the
        conds argument unchanged, and the list of additional models.
    """
    models, inference_memory = get_additional_models(conds, model.model_dtype())
    # The estimate uses twice the batch dimension of noise_shape
    # (presumably to cover cond and uncond in one batch - TODO confirm)
    # plus whatever the additional models asked for.
    batched_shape = [noise_shape[0] * 2] + list(noise_shape[1:])
    memory_required = model.memory_required(batched_shape) + inference_memory
    model_management.load_models_gpu([model] + models, memory_required)
    return model.model, conds, models
|
||||||
|
|
||||||
|
def cleanup_models(conds, models):
    """Clean up the loaded models and the control objects found in conds.

    models is passed to cleanup_additional_models first; afterwards the
    de-duplicated "control" entries of every list in conds are cleaned up
    the same way.
    """
    cleanup_additional_models(models)

    control_cleanup = []
    for name in conds:
        control_cleanup.extend(get_models_from_cond(conds[name], "control"))
    cleanup_additional_models(set(control_cleanup))
|
||||||
@ -5,6 +5,7 @@ import collections
|
|||||||
from . import model_management
|
from . import model_management
|
||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
|
from . import sampler_helpers
|
||||||
|
|
||||||
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
|
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
|
||||||
|
|
||||||
@ -130,30 +131,23 @@ def cond_cat(c_list):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
||||||
out_cond = torch.zeros_like(x_in)
|
out_conds = []
|
||||||
out_count = torch.ones_like(x_in) * 1e-37
|
out_counts = []
|
||||||
|
|
||||||
out_uncond = torch.zeros_like(x_in)
|
|
||||||
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
|
||||||
|
|
||||||
COND = 0
|
|
||||||
UNCOND = 1
|
|
||||||
|
|
||||||
to_run = []
|
to_run = []
|
||||||
for x in cond:
|
|
||||||
p = get_area_and_mult(x, x_in, timestep)
|
|
||||||
if p is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
to_run += [(p, COND)]
|
for i in range(len(conds)):
|
||||||
if uncond is not None:
|
out_conds.append(torch.zeros_like(x_in))
|
||||||
for x in uncond:
|
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||||
p = get_area_and_mult(x, x_in, timestep)
|
|
||||||
if p is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
to_run += [(p, UNCOND)]
|
cond = conds[i]
|
||||||
|
if cond is not None:
|
||||||
|
for x in cond:
|
||||||
|
p = get_area_and_mult(x, x_in, timestep)
|
||||||
|
if p is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
to_run += [(p, i)]
|
||||||
|
|
||||||
while len(to_run) > 0:
|
while len(to_run) > 0:
|
||||||
first = to_run[0]
|
first = to_run[0]
|
||||||
@ -225,74 +219,66 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
|||||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||||
else:
|
else:
|
||||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||||
del input_x
|
|
||||||
|
|
||||||
for o in range(batch_chunks):
|
for o in range(batch_chunks):
|
||||||
if cond_or_uncond[o] == COND:
|
cond_index = cond_or_uncond[o]
|
||||||
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||||
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||||
else:
|
|
||||||
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
|
||||||
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
|
||||||
del mult
|
|
||||||
|
|
||||||
out_cond /= out_count
|
for i in range(len(out_conds)):
|
||||||
del out_count
|
out_conds[i] /= out_counts[i]
|
||||||
out_uncond /= out_uncond_count
|
|
||||||
del out_uncond_count
|
return out_conds
|
||||||
return out_cond, out_uncond
|
|
||||||
|
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
    """Deprecated two-cond wrapper around calc_cond_batch.

    Logs a deprecation warning, evaluates [cond, uncond] through
    calc_cond_batch and returns the resulting list as a tuple.
    """
    logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
    cond_pair = [cond, uncond]
    predictions = calc_cond_batch(model, cond_pair, x_in, timestep, model_options)
    return tuple(predictions)
|
||||||
|
|
||||||
|
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
    """Combine cond and uncond predictions into one guided result.

    Without a "sampler_cfg_function" entry in model_options the result is
    uncond_pred + (cond_pred - uncond_pred) * cond_scale. With one, the
    callable receives an args dict (predictions expressed as x - pred, plus
    the raw denoised values) and its return value is subtracted from x.

    Every callable listed under "sampler_post_cfg_function" is then applied
    in order; each receives an args dict holding the current result under
    "denoised" and returns the replacement result.
    """
    if "sampler_cfg_function" not in model_options:
        guided = uncond_pred + (cond_pred - uncond_pred) * cond_scale
    else:
        cfg_args = {
            "cond": x - cond_pred,
            "uncond": x - uncond_pred,
            "cond_scale": cond_scale,
            "timestep": timestep,
            "input": x,
            "sigma": timestep,
            "cond_denoised": cond_pred,
            "uncond_denoised": uncond_pred,
            "model": model,
            "model_options": model_options,
        }
        guided = x - model_options["sampler_cfg_function"](cfg_args)

    for post_fn in model_options.get("sampler_post_cfg_function", []):
        post_args = {
            "denoised": guided,
            "cond": cond,
            "uncond": uncond,
            "model": model,
            "uncond_denoised": uncond_pred,
            "cond_denoised": cond_pred,
            "sigma": timestep,
            "model_options": model_options,
            "input": x,
        }
        guided = post_fn(post_args)

    return guided
|
||||||
|
|
||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns denoised
|
#Returns denoised
|
||||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||||
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
||||||
uncond_ = None
|
uncond_ = None
|
||||||
else:
|
else:
|
||||||
uncond_ = uncond
|
uncond_ = uncond
|
||||||
|
|
||||||
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
|
conds = [cond, uncond_]
|
||||||
if "sampler_cfg_function" in model_options:
|
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
|
||||||
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
|
|
||||||
cfg_result = x - model_options["sampler_cfg_function"](args)
|
|
||||||
else:
|
|
||||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
|
||||||
|
|
||||||
for fn in model_options.get("sampler_post_cfg_function", []):
|
|
||||||
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
|
||||||
"sigma": timestep, "model_options": model_options, "input": x}
|
|
||||||
cfg_result = fn(args)
|
|
||||||
|
|
||||||
return cfg_result
|
class KSamplerX0Inpaint:
|
||||||
|
|
||||||
class CFGNoisePredictor(torch.nn.Module):
|
|
||||||
def __init__(self, model):
|
|
||||||
super().__init__()
|
|
||||||
self.inner_model = model
|
|
||||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
|
||||||
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
|
||||||
return out
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
return self.apply_model(*args, **kwargs)
|
|
||||||
|
|
||||||
class KSamplerX0Inpaint(torch.nn.Module):
|
|
||||||
def __init__(self, model, sigmas):
|
def __init__(self, model, sigmas):
|
||||||
super().__init__()
|
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
self.sigmas = sigmas
|
self.sigmas = sigmas
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
|
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
if "denoise_mask_function" in model_options:
|
if "denoise_mask_function" in model_options:
|
||||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||||
latent_mask = 1. - denoise_mask
|
latent_mask = 1. - denoise_mask
|
||||||
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
|
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
|
||||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
|
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
out = out * denoise_mask + self.latent_image * latent_mask
|
out = out * denoise_mask + self.latent_image * latent_mask
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def simple_scheduler(model, steps):
|
def simple_scheduler(model_sampling, steps):
|
||||||
s = model.model_sampling
|
s = model_sampling
|
||||||
sigs = []
|
sigs = []
|
||||||
ss = len(s.sigmas) / steps
|
ss = len(s.sigmas) / steps
|
||||||
for x in range(steps):
|
for x in range(steps):
|
||||||
@ -300,8 +286,8 @@ def simple_scheduler(model, steps):
|
|||||||
sigs += [0.0]
|
sigs += [0.0]
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
def ddim_scheduler(model, steps):
|
def ddim_scheduler(model_sampling, steps):
|
||||||
s = model.model_sampling
|
s = model_sampling
|
||||||
sigs = []
|
sigs = []
|
||||||
ss = max(len(s.sigmas) // steps, 1)
|
ss = max(len(s.sigmas) // steps, 1)
|
||||||
x = 1
|
x = 1
|
||||||
@ -312,8 +298,8 @@ def ddim_scheduler(model, steps):
|
|||||||
sigs += [0.0]
|
sigs += [0.0]
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
def normal_scheduler(model, steps, sgm=False, floor=False):
|
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
||||||
s = model.model_sampling
|
s = model_sampling
|
||||||
start = s.timestep(s.sigma_max)
|
start = s.timestep(s.sigma_max)
|
||||||
end = s.timestep(s.sigma_min)
|
end = s.timestep(s.sigma_min)
|
||||||
|
|
||||||
@ -574,59 +560,120 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
|||||||
|
|
||||||
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
||||||
|
|
||||||
def wrap_model(model):
|
|
||||||
model_denoise = CFGNoisePredictor(model)
|
|
||||||
return model_denoise
|
|
||||||
|
|
||||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
|
||||||
positive = positive[:]
|
for k in conds:
|
||||||
negative = negative[:]
|
conds[k] = conds[k][:]
|
||||||
|
resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device)
|
||||||
|
|
||||||
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
|
for k in conds:
|
||||||
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
|
calculate_start_end_timesteps(model, conds[k])
|
||||||
|
|
||||||
model_wrap = wrap_model(model)
|
|
||||||
|
|
||||||
calculate_start_end_timesteps(model, negative)
|
|
||||||
calculate_start_end_timesteps(model, positive)
|
|
||||||
|
|
||||||
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
|
||||||
latent_image = model.process_latent_in(latent_image)
|
|
||||||
|
|
||||||
if hasattr(model, 'extra_conds'):
|
if hasattr(model, 'extra_conds'):
|
||||||
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
for k in conds:
|
||||||
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||||
|
|
||||||
#make sure each cond area has an opposite one with the same area
|
#make sure each cond area has an opposite one with the same area
|
||||||
for c in positive:
|
for k in conds:
|
||||||
create_cond_with_same_area_if_none(negative, c)
|
for c in conds[k]:
|
||||||
for c in negative:
|
for kk in conds:
|
||||||
create_cond_with_same_area_if_none(positive, c)
|
if k != kk:
|
||||||
|
create_cond_with_same_area_if_none(conds[kk], c)
|
||||||
|
|
||||||
pre_run_control(model, negative + positive)
|
for k in conds:
|
||||||
|
pre_run_control(model, conds[k])
|
||||||
|
|
||||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
if "positive" in conds:
|
||||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
positive = conds["positive"]
|
||||||
|
for k in conds:
|
||||||
|
if k != "positive":
|
||||||
|
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
return conds
|
||||||
|
|
||||||
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
class CFGGuider:
    """Holds a model patcher, its conditioning and a cfg scale for sampling.

    Conditioning is stored in converted form in original_conds; sample()
    works on per-run copies kept in self.conds and removes all per-run
    attributes (inner_model, conds, loaded_models) before returning.
    """

    def __init__(self, model_patcher):
        self.model_patcher = model_patcher
        self.model_options = model_patcher.model_options
        self.original_conds = {}
        self.cfg = 1.0

    def set_conds(self, positive, negative):
        """Register the conditioning under "positive" and "negative"."""
        self.inner_set_conds({"positive": positive, "negative": negative})

    def set_cfg(self, cfg):
        """Set the scale passed to sampling_function as cond_scale."""
        self.cfg = cfg

    def inner_set_conds(self, conds):
        """Store every entry of conds after sampler_helpers.convert_cond."""
        for k in conds:
            self.original_conds[k] = sampler_helpers.convert_cond(conds[k])

    def __call__(self, *args, **kwargs):
        return self.predict_noise(*args, **kwargs)

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        """Run sampling_function with "negative" as uncond and "positive" as cond.

        Only valid while sample() is running, because it reads
        self.inner_model and self.conds.
        """
        return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)

    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
        """Process the conds and run sampler.sample with self as the model.

        Returns the samples converted to float32 and passed through
        inner_model.process_latent_out.
        """
        if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
            latent_image = self.inner_model.process_latent_in(latent_image)

        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)

        extra_args = {"model_options": self.model_options, "seed":seed}

        samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
        """Prepare models, sample, and clean up again.

        Returns latent_image unchanged when sigmas is empty along its last
        dimension. Otherwise the stored conds are copied, the models are
        prepared through sampler_helpers.prepare_sampling, the inputs are
        moved to the model patcher's load_device and inner_sample is run.

        Cleanup of the models and removal of the per-run attributes happen
        in a finally block, so they also run when sampling raises.
        """
        if sigmas.shape[-1] == 0:
            return latent_image

        self.conds = {}
        for k in self.original_conds:
            self.conds[k] = [c.copy() for c in self.original_conds[k]]

        self.inner_model, self.conds, self.loaded_models = sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
        try:
            device = self.model_patcher.load_device

            if denoise_mask is not None:
                denoise_mask = sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)

            noise = noise.to(device)
            latent_image = latent_image.to(device)
            sigmas = sigmas.to(device)

            output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
        finally:
            # Without this, an exception during sampling would skip the
            # cleanup and leave the per-run references on the instance.
            sampler_helpers.cleanup_models(self.conds, self.loaded_models)
            del self.inner_model
            del self.conds
            del self.loaded_models
        return output
|
||||||
|
|
||||||
|
|
||||||
def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    """Sample through a CFGGuider built from model, positive, negative and cfg.

    The device and model_options parameters are accepted but not read by
    this function; the guider is constructed from model alone.
    """
    guider = CFGGuider(model)
    guider.set_conds(positive, negative)
    guider.set_cfg(cfg)
    return guider.sample(
        noise,
        latent_image,
        sampler,
        sigmas,
        denoise_mask=denoise_mask,
        callback=callback,
        disable_pbar=disable_pbar,
        seed=seed,
    )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||||
if scheduler_name == "karras":
|
if scheduler_name == "karras":
|
||||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||||
elif scheduler_name == "exponential":
|
elif scheduler_name == "exponential":
|
||||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||||
elif scheduler_name == "normal":
|
elif scheduler_name == "normal":
|
||||||
sigmas = normal_scheduler(model, steps)
|
sigmas = normal_scheduler(model_sampling, steps)
|
||||||
elif scheduler_name == "simple":
|
elif scheduler_name == "simple":
|
||||||
sigmas = simple_scheduler(model, steps)
|
sigmas = simple_scheduler(model_sampling, steps)
|
||||||
elif scheduler_name == "ddim_uniform":
|
elif scheduler_name == "ddim_uniform":
|
||||||
sigmas = ddim_scheduler(model, steps)
|
sigmas = ddim_scheduler(model_sampling, steps)
|
||||||
elif scheduler_name == "sgm_uniform":
|
elif scheduler_name == "sgm_uniform":
|
||||||
sigmas = normal_scheduler(model, steps, sgm=True)
|
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||||
else:
|
else:
|
||||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||||
return sigmas
|
return sigmas
|
||||||
@ -668,7 +715,7 @@ class KSampler:
|
|||||||
steps += 1
|
steps += 1
|
||||||
discard_penultimate_sigma = True
|
discard_penultimate_sigma = True
|
||||||
|
|
||||||
sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps)
|
sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps)
|
||||||
|
|
||||||
if discard_penultimate_sigma:
|
if discard_penultimate_sigma:
|
||||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||||
@ -679,9 +726,12 @@ class KSampler:
|
|||||||
if denoise is None or denoise > 0.9999:
|
if denoise is None or denoise > 0.9999:
|
||||||
self.sigmas = self.calculate_sigmas(steps).to(self.device)
|
self.sigmas = self.calculate_sigmas(steps).to(self.device)
|
||||||
else:
|
else:
|
||||||
new_steps = int(steps/denoise)
|
if denoise <= 0.0:
|
||||||
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
self.sigmas = torch.FloatTensor([])
|
||||||
self.sigmas = sigmas[-(steps + 1):]
|
else:
|
||||||
|
new_steps = int(steps/denoise)
|
||||||
|
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
||||||
|
self.sigmas = sigmas[-(steps + 1):]
|
||||||
|
|
||||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||||
if sigmas is None:
|
if sigmas is None:
|
||||||
|
|||||||
@ -600,7 +600,7 @@ def load_unet(unet_path):
|
|||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
|
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||||
clip_sd = None
|
clip_sd = None
|
||||||
load_models = [model]
|
load_models = [model]
|
||||||
if clip is not None:
|
if clip is not None:
|
||||||
@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
|||||||
model_management.load_models_gpu(load_models)
|
model_management.load_models_gpu(load_models)
|
||||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||||
|
for k in extra_keys:
|
||||||
|
sd[k] = extra_keys[k]
|
||||||
|
|
||||||
utils.save_torch_file(sd, output_path, metadata=metadata)
|
utils.save_torch_file(sd, output_path, metadata=metadata)
|
||||||
|
|||||||
@ -174,6 +174,11 @@ class SDXL(supported_models_base.BASE):
|
|||||||
self.sampling_settings["sigma_max"] = 80.0
|
self.sampling_settings["sigma_max"] = 80.0
|
||||||
self.sampling_settings["sigma_min"] = 0.002
|
self.sampling_settings["sigma_min"] = 0.002
|
||||||
return model_base.ModelType.EDM
|
return model_base.ModelType.EDM
|
||||||
|
elif "edm_vpred.sigma_max" in state_dict:
|
||||||
|
self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
|
||||||
|
if "edm_vpred.sigma_min" in state_dict:
|
||||||
|
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
|
||||||
|
return model_base.ModelType.V_PREDICTION_EDM
|
||||||
elif "v_pred" in state_dict:
|
elif "v_pred" in state_dict:
|
||||||
return model_base.ModelType.V_PREDICTION
|
return model_base.ModelType.V_PREDICTION
|
||||||
else:
|
else:
|
||||||
@ -334,6 +339,11 @@ class Stable_Zero123(supported_models_base.BASE):
|
|||||||
"num_head_channels": -1,
|
"num_head_channels": -1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
required_keys = {
|
||||||
|
"cc_projection.weight": None,
|
||||||
|
"cc_projection.bias": None,
|
||||||
|
}
|
||||||
|
|
||||||
clip_vision_prefix = "cond_stage_model.model.visual."
|
clip_vision_prefix = "cond_stage_model.model.visual."
|
||||||
|
|
||||||
latent_format = latent_formats.SD15
|
latent_format = latent_formats.SD15
|
||||||
@ -439,6 +449,33 @@ class Stable_Cascade_B(Stable_Cascade_C):
|
|||||||
out = model_base.StableCascade_B(self, device=device)
|
out = model_base.StableCascade_B(self, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class SD15_instructpix2pix(SD15):
    """SD15 variant whose unet config declares 8 input channels.

    get_model builds a model_base.SD15_instructpix2pix instead of the
    model produced by SD15.
    """

    # Replaces the unet_config inherited from SD15 as a whole.
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Return the model object; state_dict and prefix are not read here."""
        return model_base.SD15_instructpix2pix(self, device=device)
|
||||||
|
|
||||||
|
class SDXL_instructpix2pix(SDXL):
    """SDXL variant whose unet config declares 8 input channels.

    get_model builds a model_base.SDXL_instructpix2pix, passing along the
    model type determined by self.model_type.
    """

    # Replaces the unet_config inherited from SDXL as a whole.
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Return the model object for the type found in state_dict."""
        detected_type = self.model_type(state_dict, prefix)
        return model_base.SDXL_instructpix2pix(self, model_type=detected_type, device=device)
|
||||||
|
|
||||||
|
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
|
||||||
|
|
||||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@ -16,6 +16,8 @@ class BASE:
|
|||||||
"num_head_channels": 64,
|
"num_head_channels": 64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
required_keys = {}
|
||||||
|
|
||||||
clip_prefix = []
|
clip_prefix = []
|
||||||
clip_vision_prefix = None
|
clip_vision_prefix = None
|
||||||
noise_aug_config = None
|
noise_aug_config = None
|
||||||
@ -28,10 +30,14 @@ class BASE:
|
|||||||
manual_cast_dtype = None
|
manual_cast_dtype = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def matches(s, unet_config):
|
def matches(s, unet_config, state_dict=None):
|
||||||
for k in s.unet_config:
|
for k in s.unet_config:
|
||||||
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
||||||
return False
|
return False
|
||||||
|
if state_dict is not None:
|
||||||
|
for k in s.required_keys:
|
||||||
|
if k not in state_dict:
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def model_type(self, state_dict, prefix=""):
|
def model_type(self, state_dict, prefix=""):
|
||||||
|
|||||||
@ -7247,7 +7247,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
//create links
|
//create links
|
||||||
for (var i = 0; i < clipboard_info.links.length; ++i) {
|
for (var i = 0; i < clipboard_info.links.length; ++i) {
|
||||||
var link_info = clipboard_info.links[i];
|
var link_info = clipboard_info.links[i];
|
||||||
var origin_node;
|
var origin_node = undefined;
|
||||||
var origin_node_relative_id = link_info[0];
|
var origin_node_relative_id = link_info[0];
|
||||||
if (origin_node_relative_id != null) {
|
if (origin_node_relative_id != null) {
|
||||||
origin_node = nodes[origin_node_relative_id];
|
origin_node = nodes[origin_node_relative_id];
|
||||||
|
|||||||
@ -170,9 +170,12 @@ export async function importA1111(graph, parameters) {
|
|||||||
const opts = parameters
|
const opts = parameters
|
||||||
.substr(p)
|
.substr(p)
|
||||||
.split("\n")[1]
|
.split("\n")[1]
|
||||||
.split(",")
|
.match(new RegExp("\\s*([^:]+:\\s*([^\"\\{].*?|\".*?\"|\\{.*?\\}))\\s*(,|$)", "g"))
|
||||||
.reduce((p, n) => {
|
.reduce((p, n) => {
|
||||||
const s = n.split(":");
|
const s = n.split(":");
|
||||||
|
if (s[1].endsWith(',')) {
|
||||||
|
s[1] = s[1].substr(0, s[1].length -1);
|
||||||
|
}
|
||||||
p[s[0].trim().toLowerCase()] = s[1].trim();
|
p[s[0].trim().toLowerCase()] = s[1].trim();
|
||||||
return p;
|
return p;
|
||||||
}, {});
|
}, {});
|
||||||
@ -191,6 +194,7 @@ export async function importA1111(graph, parameters) {
|
|||||||
const vaeLoaderNode = LiteGraph.createNode("VAELoader");
|
const vaeLoaderNode = LiteGraph.createNode("VAELoader");
|
||||||
const saveNode = LiteGraph.createNode("SaveImage");
|
const saveNode = LiteGraph.createNode("SaveImage");
|
||||||
let hrSamplerNode = null;
|
let hrSamplerNode = null;
|
||||||
|
let hrSteps = null;
|
||||||
|
|
||||||
const ceil64 = (v) => Math.ceil(v / 64) * 64;
|
const ceil64 = (v) => Math.ceil(v / 64) * 64;
|
||||||
|
|
||||||
@ -290,6 +294,9 @@ export async function importA1111(graph, parameters) {
|
|||||||
model(v) {
|
model(v) {
|
||||||
setWidgetValue(ckptNode, "ckpt_name", v, true);
|
setWidgetValue(ckptNode, "ckpt_name", v, true);
|
||||||
},
|
},
|
||||||
|
"vae"(v) {
|
||||||
|
setWidgetValue(vaeLoaderNode, "vae_name", v, true);
|
||||||
|
},
|
||||||
"cfg scale"(v) {
|
"cfg scale"(v) {
|
||||||
setWidgetValue(samplerNode, "cfg", +v);
|
setWidgetValue(samplerNode, "cfg", +v);
|
||||||
},
|
},
|
||||||
@ -316,6 +323,7 @@ export async function importA1111(graph, parameters) {
|
|||||||
const h = ceil64(+wxh[1]);
|
const h = ceil64(+wxh[1]);
|
||||||
const hrUp = popOpt("hires upscale");
|
const hrUp = popOpt("hires upscale");
|
||||||
const hrSz = popOpt("hires resize");
|
const hrSz = popOpt("hires resize");
|
||||||
|
hrSteps = popOpt("hires steps");
|
||||||
let hrMethod = popOpt("hires upscaler");
|
let hrMethod = popOpt("hires upscaler");
|
||||||
|
|
||||||
setWidgetValue(imageNode, "width", w);
|
setWidgetValue(imageNode, "width", w);
|
||||||
@ -398,7 +406,7 @@ export async function importA1111(graph, parameters) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (hrSamplerNode) {
|
if (hrSamplerNode) {
|
||||||
setWidgetValue(hrSamplerNode, "steps", getWidget(samplerNode, "steps").value);
|
setWidgetValue(hrSamplerNode, "steps", hrSteps? +hrSteps : getWidget(samplerNode, "steps").value);
|
||||||
setWidgetValue(hrSamplerNode, "cfg", getWidget(samplerNode, "cfg").value);
|
setWidgetValue(hrSamplerNode, "cfg", getWidget(samplerNode, "cfg").value);
|
||||||
setWidgetValue(hrSamplerNode, "scheduler", getWidget(samplerNode, "scheduler").value);
|
setWidgetValue(hrSamplerNode, "scheduler", getWidget(samplerNode, "scheduler").value);
|
||||||
setWidgetValue(hrSamplerNode, "sampler_name", getWidget(samplerNode, "sampler_name").value);
|
setWidgetValue(hrSamplerNode, "sampler_name", getWidget(samplerNode, "sampler_name").value);
|
||||||
@ -415,7 +423,7 @@ export async function importA1111(graph, parameters) {
|
|||||||
|
|
||||||
graph.arrange();
|
graph.arrange();
|
||||||
|
|
||||||
for (const opt of ["model hash", "ensd"]) {
|
for (const opt of ["model hash", "ensd", "version", "vae hash", "ti hashes", "lora hashes", "hashes"]) {
|
||||||
delete opts[opt];
|
delete opts[opt];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from comfy.k_diffusion import sampling as k_diffusion_sampling
|
|||||||
from comfy.cmd import latent_preview
|
from comfy.cmd import latent_preview
|
||||||
import torch
|
import torch
|
||||||
from comfy import utils
|
from comfy import utils
|
||||||
|
from comfy import node_helpers
|
||||||
|
|
||||||
|
|
||||||
class BasicScheduler:
|
class BasicScheduler:
|
||||||
@ -26,10 +27,11 @@ class BasicScheduler:
|
|||||||
def get_sigmas(self, model, scheduler, steps, denoise):
|
def get_sigmas(self, model, scheduler, steps, denoise):
|
||||||
total_steps = steps
|
total_steps = steps
|
||||||
if denoise < 1.0:
|
if denoise < 1.0:
|
||||||
|
if denoise <= 0.0:
|
||||||
|
return (torch.FloatTensor([]),)
|
||||||
total_steps = int(steps/denoise)
|
total_steps = int(steps/denoise)
|
||||||
|
|
||||||
model_management.load_models_gpu([model])
|
sigmas = samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu()
|
||||||
sigmas = samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
|
|
||||||
sigmas = sigmas[-(steps + 1):]
|
sigmas = sigmas[-(steps + 1):]
|
||||||
return (sigmas, )
|
return (sigmas, )
|
||||||
|
|
||||||
@ -162,6 +164,9 @@ class FlipSigmas:
|
|||||||
FUNCTION = "get_sigmas"
|
FUNCTION = "get_sigmas"
|
||||||
|
|
||||||
def get_sigmas(self, sigmas):
|
def get_sigmas(self, sigmas):
|
||||||
|
if len(sigmas) == 0:
|
||||||
|
return (sigmas,)
|
||||||
|
|
||||||
sigmas = sigmas.flip(0)
|
sigmas = sigmas.flip(0)
|
||||||
if sigmas[0] == 0:
|
if sigmas[0] == 0:
|
||||||
sigmas[0] = 0.0001
|
sigmas[0] = 0.0001
|
||||||
@ -334,6 +339,24 @@ class SamplerDPMAdaptative:
|
|||||||
"s_noise":s_noise })
|
"s_noise":s_noise })
|
||||||
return (sampler, )
|
return (sampler, )
|
||||||
|
|
||||||
|
class Noise_EmptyNoise:
|
||||||
|
def __init__(self):
|
||||||
|
self.seed = 0
|
||||||
|
|
||||||
|
def generate_noise(self, input_latent):
|
||||||
|
latent_image = input_latent["samples"]
|
||||||
|
return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||||
|
|
||||||
|
|
||||||
|
class Noise_RandomNoise:
|
||||||
|
def __init__(self, seed):
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def generate_noise(self, input_latent):
|
||||||
|
latent_image = input_latent["samples"]
|
||||||
|
batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None
|
||||||
|
return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds)
|
||||||
|
|
||||||
class SamplerCustom:
|
class SamplerCustom:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -361,10 +384,9 @@ class SamplerCustom:
|
|||||||
latent = latent_image
|
latent = latent_image
|
||||||
latent_image = latent["samples"]
|
latent_image = latent["samples"]
|
||||||
if not add_noise:
|
if not add_noise:
|
||||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
noise = Noise_EmptyNoise().generate_noise(latent)
|
||||||
else:
|
else:
|
||||||
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
noise = Noise_RandomNoise(noise_seed).generate_noise(latent)
|
||||||
noise = sample.prepare_noise(latent_image, noise_seed, batch_inds)
|
|
||||||
|
|
||||||
noise_mask = None
|
noise_mask = None
|
||||||
if "noise_mask" in latent:
|
if "noise_mask" in latent:
|
||||||
@ -385,6 +407,161 @@ class SamplerCustom:
|
|||||||
out_denoised = out
|
out_denoised = out
|
||||||
return (out, out_denoised)
|
return (out, out_denoised)
|
||||||
|
|
||||||
|
class Guider_Basic(comfy.samplers.CFGGuider):
|
||||||
|
def set_conds(self, positive):
|
||||||
|
self.inner_set_conds({"positive": positive})
|
||||||
|
|
||||||
|
class BasicGuider:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"model": ("MODEL",),
|
||||||
|
"conditioning": ("CONDITIONING", ),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("GUIDER",)
|
||||||
|
|
||||||
|
FUNCTION = "get_guider"
|
||||||
|
CATEGORY = "sampling/custom_sampling/guiders"
|
||||||
|
|
||||||
|
def get_guider(self, model, conditioning):
|
||||||
|
guider = Guider_Basic(model)
|
||||||
|
guider.set_conds(conditioning)
|
||||||
|
return (guider,)
|
||||||
|
|
||||||
|
class CFGGuider:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"model": ("MODEL",),
|
||||||
|
"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("GUIDER",)
|
||||||
|
|
||||||
|
FUNCTION = "get_guider"
|
||||||
|
CATEGORY = "sampling/custom_sampling/guiders"
|
||||||
|
|
||||||
|
def get_guider(self, model, positive, negative, cfg):
|
||||||
|
guider = comfy.samplers.CFGGuider(model)
|
||||||
|
guider.set_conds(positive, negative)
|
||||||
|
guider.set_cfg(cfg)
|
||||||
|
return (guider,)
|
||||||
|
|
||||||
|
class Guider_DualCFG(comfy.samplers.CFGGuider):
|
||||||
|
def set_cfg(self, cfg1, cfg2):
|
||||||
|
self.cfg1 = cfg1
|
||||||
|
self.cfg2 = cfg2
|
||||||
|
|
||||||
|
def set_conds(self, positive, middle, negative):
|
||||||
|
middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
|
||||||
|
self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative})
|
||||||
|
|
||||||
|
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||||
|
negative_cond = self.conds.get("negative", None)
|
||||||
|
middle_cond = self.conds.get("middle", None)
|
||||||
|
|
||||||
|
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options)
|
||||||
|
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
|
||||||
|
|
||||||
|
class DualCFGGuider:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"model": ("MODEL",),
|
||||||
|
"cond1": ("CONDITIONING", ),
|
||||||
|
"cond2": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||||
|
"cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("GUIDER",)
|
||||||
|
|
||||||
|
FUNCTION = "get_guider"
|
||||||
|
CATEGORY = "sampling/custom_sampling/guiders"
|
||||||
|
|
||||||
|
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
|
||||||
|
guider = Guider_DualCFG(model)
|
||||||
|
guider.set_conds(cond1, cond2, negative)
|
||||||
|
guider.set_cfg(cfg_conds, cfg_cond2_negative)
|
||||||
|
return (guider,)
|
||||||
|
|
||||||
|
class DisableNoise:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":{
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("NOISE",)
|
||||||
|
FUNCTION = "get_noise"
|
||||||
|
CATEGORY = "sampling/custom_sampling/noise"
|
||||||
|
|
||||||
|
def get_noise(self):
|
||||||
|
return (Noise_EmptyNoise(),)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomNoise(DisableNoise):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":{
|
||||||
|
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_noise(self, noise_seed):
|
||||||
|
return (Noise_RandomNoise(noise_seed),)
|
||||||
|
|
||||||
|
|
||||||
|
class SamplerCustomAdvanced:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"noise": ("NOISE", ),
|
||||||
|
"guider": ("GUIDER", ),
|
||||||
|
"sampler": ("SAMPLER", ),
|
||||||
|
"sigmas": ("SIGMAS", ),
|
||||||
|
"latent_image": ("LATENT", ),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT","LATENT")
|
||||||
|
RETURN_NAMES = ("output", "denoised_output")
|
||||||
|
|
||||||
|
FUNCTION = "sample"
|
||||||
|
|
||||||
|
CATEGORY = "sampling/custom_sampling"
|
||||||
|
|
||||||
|
def sample(self, noise, guider, sampler, sigmas, latent_image):
|
||||||
|
latent = latent_image
|
||||||
|
latent_image = latent["samples"]
|
||||||
|
|
||||||
|
noise_mask = None
|
||||||
|
if "noise_mask" in latent:
|
||||||
|
noise_mask = latent["noise_mask"]
|
||||||
|
|
||||||
|
x0_output = {}
|
||||||
|
callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
|
||||||
|
|
||||||
|
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
||||||
|
samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
|
||||||
|
samples = samples.to(comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
out = latent.copy()
|
||||||
|
out["samples"] = samples
|
||||||
|
if "x0" in x0_output:
|
||||||
|
out_denoised = latent.copy()
|
||||||
|
out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
|
||||||
|
else:
|
||||||
|
out_denoised = out
|
||||||
|
return (out, out_denoised)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"SamplerCustom": SamplerCustom,
|
"SamplerCustom": SamplerCustom,
|
||||||
"BasicScheduler": BasicScheduler,
|
"BasicScheduler": BasicScheduler,
|
||||||
@ -402,4 +579,11 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"SamplerDPMAdaptative": SamplerDPMAdaptative,
|
"SamplerDPMAdaptative": SamplerDPMAdaptative,
|
||||||
"SplitSigmas": SplitSigmas,
|
"SplitSigmas": SplitSigmas,
|
||||||
"FlipSigmas": FlipSigmas,
|
"FlipSigmas": FlipSigmas,
|
||||||
|
|
||||||
|
"CFGGuider": CFGGuider,
|
||||||
|
"DualCFGGuider": DualCFGGuider,
|
||||||
|
"BasicGuider": BasicGuider,
|
||||||
|
"RandomNoise": RandomNoise,
|
||||||
|
"DisableNoise": DisableNoise,
|
||||||
|
"SamplerCustomAdvanced": SamplerCustomAdvanced,
|
||||||
}
|
}
|
||||||
|
|||||||
45
comfy_extras/nodes/nodes_ip2p.py
Normal file
45
comfy_extras/nodes/nodes_ip2p.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class InstructPixToPixConditioning:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"positive": ("CONDITIONING", ),
|
||||||
|
"negative": ("CONDITIONING", ),
|
||||||
|
"vae": ("VAE", ),
|
||||||
|
"pixels": ("IMAGE", ),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
|
||||||
|
RETURN_NAMES = ("positive", "negative", "latent")
|
||||||
|
FUNCTION = "encode"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning/instructpix2pix"
|
||||||
|
|
||||||
|
def encode(self, positive, negative, pixels, vae):
|
||||||
|
x = (pixels.shape[1] // 8) * 8
|
||||||
|
y = (pixels.shape[2] // 8) * 8
|
||||||
|
|
||||||
|
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||||
|
x_offset = (pixels.shape[1] % 8) // 2
|
||||||
|
y_offset = (pixels.shape[2] % 8) // 2
|
||||||
|
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
|
||||||
|
|
||||||
|
concat_latent = vae.encode(pixels)
|
||||||
|
|
||||||
|
out_latent = {}
|
||||||
|
out_latent["samples"] = torch.zeros_like(concat_latent)
|
||||||
|
|
||||||
|
out = []
|
||||||
|
for conditioning in [positive, negative]:
|
||||||
|
c = []
|
||||||
|
for t in conditioning:
|
||||||
|
d = t[1].copy()
|
||||||
|
d["concat_latent_image"] = concat_latent
|
||||||
|
n = [t[0], d]
|
||||||
|
c.append(n)
|
||||||
|
out.append(c)
|
||||||
|
return (out[0], out[1], out_latent)
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"InstructPixToPixConditioning": InstructPixToPixConditioning,
|
||||||
|
}
|
||||||
@ -1,8 +1,10 @@
|
|||||||
from comfy import sd, utils
|
from comfy import sd, utils
|
||||||
from comfy import model_base
|
from comfy import model_base
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
|
from comfy import model_sampling
|
||||||
from comfy.cmd import folder_paths
|
from comfy.cmd import folder_paths
|
||||||
|
|
||||||
|
import torch
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@ -188,6 +190,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
|||||||
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
|
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
|
||||||
# "v2-inpainting"
|
# "v2-inpainting"
|
||||||
|
|
||||||
|
extra_keys = {}
|
||||||
|
_model_sampling = model.get_model_object("model_sampling")
|
||||||
|
if isinstance(_model_sampling, model_sampling.ModelSamplingContinuousEDM):
|
||||||
|
if isinstance(_model_sampling, model_sampling.V_PREDICTION):
|
||||||
|
extra_keys["edm_vpred.sigma_max"] = torch.tensor(_model_sampling.sigma_max).float()
|
||||||
|
extra_keys["edm_vpred.sigma_min"] = torch.tensor(_model_sampling.sigma_min).float()
|
||||||
|
|
||||||
if model.model.model_type == model_base.ModelType.EPS:
|
if model.model.model_type == model_base.ModelType.EPS:
|
||||||
metadata["modelspec.predict_key"] = "epsilon"
|
metadata["modelspec.predict_key"] = "epsilon"
|
||||||
elif model.model.model_type == model_base.ModelType.V_PREDICTION:
|
elif model.model.model_type == model_base.ModelType.V_PREDICTION:
|
||||||
@ -202,7 +211,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
|||||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
|
|
||||||
sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
|
sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
|
||||||
|
|
||||||
class CheckpointSave:
|
class CheckpointSave:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
from comfy import sample
|
from comfy import sample
|
||||||
from comfy import samplers
|
from comfy import samplers
|
||||||
|
from comfy import sampler_helpers
|
||||||
|
|
||||||
|
|
||||||
|
#TODO: This node should be removed and replaced with one that uses the new Guider/SamplerCustomAdvanced.
|
||||||
class PerpNeg:
|
class PerpNeg:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -17,7 +19,7 @@ class PerpNeg:
|
|||||||
|
|
||||||
def patch(self, model, empty_conditioning, neg_scale):
|
def patch(self, model, empty_conditioning, neg_scale):
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
nocond = sample.convert_cond(empty_conditioning)
|
nocond = sampler_helpers.convert_cond(empty_conditioning)
|
||||||
|
|
||||||
def cfg_function(args):
|
def cfg_function(args):
|
||||||
model = args["model"]
|
model = args["model"]
|
||||||
@ -29,7 +31,7 @@ class PerpNeg:
|
|||||||
model_options = args["model_options"]
|
model_options = args["model_options"]
|
||||||
nocond_processed = samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
|
nocond_processed = samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
|
||||||
|
|
||||||
(noise_pred_nocond, _) = samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
|
(noise_pred_nocond,) = samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options)
|
||||||
|
|
||||||
pos = noise_pred_pos - noise_pred_nocond
|
pos = noise_pred_pos - noise_pred_nocond
|
||||||
neg = noise_pred_neg - noise_pred_nocond
|
neg = noise_pred_neg - noise_pred_nocond
|
||||||
|
|||||||
@ -150,7 +150,7 @@ class SelfAttentionGuidance:
|
|||||||
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
|
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
|
||||||
degraded_noised = degraded + x - uncond_pred
|
degraded_noised = degraded + x - uncond_pred
|
||||||
# call into the UNet
|
# call into the UNet
|
||||||
(sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
|
(sag,) = samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
|
||||||
return cfg_result + (degraded - sag) * sag_scale
|
return cfg_result + (degraded - sag) * sag_scale
|
||||||
|
|
||||||
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
|
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user