Merge branch 'master' into execution_model_inversion

Jacob Segal 2024-04-21 21:31:58 -07:00
commit b3e547f22b
41 changed files with 1144 additions and 356 deletions

README.md

@@ -142,7 +142,7 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
 1. Install pytorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest pytorch nightly).
 1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
 1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
-1. Launch ComfyUI by running `python main.py --force-fp16`. Note that --force-fp16 will only work if you installed the latest pytorch nightly.
+1. Launch ComfyUI by running `python main.py`
 > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).

comfy/controlnet.py

@@ -138,11 +138,13 @@ class ControlBase:
         return out

 class ControlNet(ControlBase):
-    def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
+    def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
-        self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
+        if control_model is not None:
+            self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
+
         self.global_average_pooling = global_average_pooling
         self.model_sampling_current = None
         self.manual_cast_dtype = manual_cast_dtype
@@ -183,7 +185,9 @@ class ControlNet(ControlBase):
         return self.control_merge(None, control, control_prev, output_dtype)

     def copy(self):
-        c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
         self.copy_to(c)
         return c
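
With this change a copied ControlNet shares the wrapped model with its source instead of building a fresh ModelPatcher per copy. A quick illustration of the sharing (not part of the diff, assuming a loaded ControlNet instance cn):

    cn2 = cn.copy()
    assert cn2.control_model is cn.control_model                  # same torch module
    assert cn2.control_model_wrapped is cn.control_model_wrapped  # same ModelPatcher,
                                                                   # so model management sees one model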

comfy/diffusers_convert.py

@@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys()))
 # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
 code2idx = {"q": 0, "k": 1, "v": 2}

+# This function exists because at the time of writing torch.cat can't do fp8 with cuda
+def cat_tensors(tensors):
+    x = 0
+    for t in tensors:
+        x += t.shape[0]
+
+    shape = [x] + list(tensors[0].shape)[1:]
+    out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
+
+    x = 0
+    for t in tensors:
+        out[x:x + t.shape[0]] = t
+        x += t.shape[0]
+
+    return out

 def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
     new_state_dict = {}
@@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
         if None in tensors:
             raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
-        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+        new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)

     for k_pre, tensors in capture_qkv_bias.items():
         if None in tensors:
             raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
-        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+        new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)

     return new_state_dict
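
cat_tensors sidesteps torch.cat for dtypes whose CUDA kernels may be missing (fp8 at the time of writing) by preallocating with torch.empty and slice-assigning. A standalone sketch of the equivalence on an ordinary dtype (illustration only, not from the diff):

    import torch

    a = torch.randn(2, 4)
    b = torch.randn(3, 4)
    out = torch.empty([5, 4], dtype=a.dtype)  # what cat_tensors builds
    out[0:2] = a                              # then fills slice by slice
    out[2:5] = b
    assert torch.equal(out, torch.cat([a, b]))  # same result as torch.cat on dim 0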

comfy/lora.py

@@ -21,6 +21,12 @@ def load_lora(lora, to_load):
             alpha = lora[alpha_name].item()
             loaded_keys.add(alpha_name)

+        dora_scale_name = "{}.dora_scale".format(x)
+        dora_scale = None
+        if dora_scale_name in lora.keys():
+            dora_scale = lora[dora_scale_name]
+            loaded_keys.add(dora_scale_name)
+
         regular_lora = "{}.lora_up.weight".format(x)
         diffusers_lora = "{}_lora.up.weight".format(x)
         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
@@ -44,7 +50,7 @@ def load_lora(lora, to_load):
             if mid_name is not None and mid_name in lora.keys():
                 mid = lora[mid_name]
                 loaded_keys.add(mid_name)
-            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
+            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
             loaded_keys.add(A_name)
             loaded_keys.add(B_name)
@@ -65,7 +71,7 @@ def load_lora(lora, to_load):
                 loaded_keys.add(hada_t1_name)
                 loaded_keys.add(hada_t2_name)

-            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
+            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
             loaded_keys.add(hada_w1_a_name)
             loaded_keys.add(hada_w1_b_name)
             loaded_keys.add(hada_w2_a_name)
@@ -117,7 +123,7 @@ def load_lora(lora, to_load):
                 loaded_keys.add(lokr_t2_name)

         if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
-            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
+            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))

         #glora
         a1_name = "{}.a1.weight".format(x)
@@ -125,7 +131,7 @@ def load_lora(lora, to_load):
         b1_name = "{}.b1.weight".format(x)
         b2_name = "{}.b2.weight".format(x)
         if a1_name in lora:
-            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
+            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
             loaded_keys.add(a1_name)
             loaded_keys.add(a2_name)
             loaded_keys.add(b1_name)
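
Every patch tuple (lora, loha, lokr, glora) now carries an optional trailing dora_scale tensor, picked up from a `{key}.dora_scale` entry when one exists. An illustration of the key layout this detects, with a hypothetical module name and made-up shapes (not from the diff):

    import torch

    lora = {
        "m.lora_up.weight": torch.randn(8, 2),    # standard LoRA factors
        "m.lora_down.weight": torch.randn(2, 4),
        "m.alpha": torch.tensor(2.0),
        "m.dora_scale": torch.randn(1, 4),        # extra DoRA magnitude tensor; absent in plain LoRA files
    }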

comfy/model_base.py

@@ -66,7 +66,8 @@ class BaseModel(torch.nn.Module):
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
             self.adm_channels = 0
-        self.inpaint_model = False
+
+        self.concat_keys = ()
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
@@ -107,8 +108,7 @@ class BaseModel(torch.nn.Module):
     def extra_conds(self, **kwargs):
         out = {}
-        if self.inpaint_model:
-            concat_keys = ("mask", "masked_image")
+        if len(self.concat_keys) > 0:
             cond_concat = []
             denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
             concat_latent_image = kwargs.get("concat_latent_image", None)
@@ -125,24 +125,16 @@ class BaseModel(torch.nn.Module):
                 concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])

-            if len(denoise_mask.shape) == len(noise.shape):
-                denoise_mask = denoise_mask[:,:1]
-
-            denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
-            if denoise_mask.shape[-2:] != noise.shape[-2:]:
-                denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
-            denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
-
-            def blank_inpaint_image_like(latent_image):
-                blank_image = torch.ones_like(latent_image)
-                # these are the values for "zero" in pixel space translated to latent space
-                blank_image[:,0] *= 0.8223
-                blank_image[:,1] *= -0.6876
-                blank_image[:,2] *= 0.6364
-                blank_image[:,3] *= 0.1380
-                return blank_image
-
-            for ck in concat_keys:
+            if denoise_mask is not None:
+                if len(denoise_mask.shape) == len(noise.shape):
+                    denoise_mask = denoise_mask[:,:1]
+
+                denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
+                if denoise_mask.shape[-2:] != noise.shape[-2:]:
+                    denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+                denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
+
+            for ck in self.concat_keys:
                 if denoise_mask is not None:
                     if ck == "mask":
                         cond_concat.append(denoise_mask.to(device))
@@ -152,7 +144,7 @@ class BaseModel(torch.nn.Module):
                     if ck == "mask":
                         cond_concat.append(torch.ones_like(noise)[:,:1])
                     elif ck == "masked_image":
-                        cond_concat.append(blank_inpaint_image_like(noise))
+                        cond_concat.append(self.blank_inpaint_image_like(noise))
             data = torch.cat(cond_concat, dim=1)
             out['c_concat'] = comfy.conds.CONDNoiseShape(data)
@@ -221,7 +213,16 @@ class BaseModel(torch.nn.Module):
         return unet_state_dict

     def set_inpaint(self):
-        self.inpaint_model = True
+        self.concat_keys = ("mask", "masked_image")
+        def blank_inpaint_image_like(latent_image):
+            blank_image = torch.ones_like(latent_image)
+            # these are the values for "zero" in pixel space translated to latent space
+            blank_image[:,0] *= 0.8223
+            blank_image[:,1] *= -0.6876
+            blank_image[:,2] *= 0.6364
+            blank_image[:,3] *= 0.1380
+            return blank_image
+        self.blank_inpaint_image_like = blank_inpaint_image_like

     def memory_required(self, input_shape):
         if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
@@ -472,6 +473,42 @@ class SD_X4Upscaler(BaseModel):
         out['y'] = comfy.conds.CONDRegular(noise_level)
         return out

+class IP2P:
+    def extra_conds(self, **kwargs):
+        out = {}
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        if image.shape[1:] != noise.shape[1:]:
+            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+        out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
+
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out['y'] = comfy.conds.CONDRegular(adm)
+        return out
+
+class SD15_instructpix2pix(IP2P, BaseModel):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.process_ip2p_image_in = lambda image: image
+
+class SDXL_instructpix2pix(IP2P, SDXL):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
+        if model_type == ModelType.V_PREDICTION_EDM:
+            self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image) #cosxl ip2p
+        else:
+            self.process_ip2p_image_in = lambda image: image #diffusers ip2p
+
 class StableCascade_C(BaseModel):
     def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=StageC)
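
Because IP2P is listed first in the bases, its extra_conds sits ahead of the base model classes in the MRO, so the instruct-pix2pix concat conditioning takes precedence; only process_ip2p_image_in differs per model family. A quick check (illustration only, not from the diff):

    print(SDXL_instructpix2pix.__mro__[:3])
    # -> SDXL_instructpix2pix, IP2P, SDXL, ...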

comfy/model_detection.py

@@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix):

     return unet_config

-def model_config_from_unet_config(unet_config):
+def model_config_from_unet_config(unet_config, state_dict=None):
     for model_config in comfy.supported_models.models:
-        if model_config.matches(unet_config):
+        if model_config.matches(unet_config, state_dict):
             return model_config(unet_config)

     logging.error("no match {}".format(unet_config))
@@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config):

 def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
     unet_config = detect_unet_config(state_dict, unet_key_prefix)
-    model_config = model_config_from_unet_config(unet_config)
+    model_config = model_config_from_unet_config(unet_config, state_dict)
     if model_config is None and use_base_if_no_match:
         return comfy.supported_models_base.BASE(unet_config)
     else:
@@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                          'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
                          'use_temporal_attention': False, 'use_temporal_resblock': False}

+    SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                           'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
+                           'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
+                           'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
+                           'use_temporal_attention': False, 'use_temporal_resblock': False}
+
     SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
               'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
               'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
@@ -345,7 +351,20 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                  'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                  'use_temporal_attention': False, 'use_temporal_resblock': False}

-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
+    SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+               'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
+               'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True,
+               'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
+               'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
+
+    SD_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+             'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
+             'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
+             'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
+             'use_temporal_attention': False, 'use_temporal_resblock': False}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]

     for unet_config in supported_models:
         matches = True
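
matches() now optionally receives the raw state_dict, letting a model config disambiguate checkpoints whose UNet hyperparameters look identical. A hedged sketch of how a config could use it (hypothetical class and weight key, not from this commit):

    import comfy.supported_models_base

    class HypotheticalConfig(comfy.supported_models_base.BASE):
        @classmethod
        def matches(cls, unet_config, state_dict=None):
            if state_dict is not None:
                # e.g. tell an 8-channel ip2p conv-in apart from a 4-channel one
                w = state_dict.get("model.diffusion_model.input_blocks.0.0.weight", None)
                if w is not None and w.shape[1] != 8:
                    return False
            return super().matches(unet_config)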

comfy/model_management.py

@@ -274,6 +274,7 @@ class LoadedModel:
         self.model = model
         self.device = model.load_device
         self.weights_loaded = False
+        self.real_model = None

     def model_memory(self):
         return self.model.model_size()
@@ -312,6 +313,7 @@ class LoadedModel:
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.real_model = None

     def __eq__(self, other):
         return self.model is other.model
@@ -326,7 +328,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
             to_unload = [i] + to_unload

     if len(to_unload) == 0:
-        return None
+        return True

     same_weights = 0
     for i in to_unload:
@@ -349,20 +351,27 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
     return unload_weight

 def free_memory(memory_required, device, keep_loaded=[]):
-    unloaded_model = False
+    unloaded_model = []
+    can_unload = []
+
     for i in range(len(current_loaded_models) -1, -1, -1):
-        if not DISABLE_SMART_MEMORY:
-            if get_free_memory(device) > memory_required:
-                break
         shift_model = current_loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded:
-                m = current_loaded_models.pop(i)
-                m.model_unload()
-                del m
-                unloaded_model = True
+                can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))

-    if unloaded_model:
+    for x in sorted(can_unload):
+        i = x[-1]
+        if not DISABLE_SMART_MEMORY:
+            if get_free_memory(device) > memory_required:
+                break
+        current_loaded_models[i].model_unload()
+        unloaded_model.append(i)
+
+    for i in sorted(unloaded_model, reverse=True):
+        current_loaded_models.pop(i)
+
+    if len(unloaded_model) > 0:
         soft_empty_cache()
     else:
         if vram_state != VRAMState.HIGH_VRAM:
@@ -376,6 +385,8 @@ def load_models_gpu(models, memory_required=0):
     inference_memory = minimum_inference_memory()
     extra_mem = max(inference_memory, memory_required)

+    models = set(models)
+
     models_to_load = []
     models_already_loaded = []
     for x in models:
@@ -401,8 +412,8 @@ def load_models_gpu(models, memory_required=0):
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True: #unload clones where the weights are different
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

     for device in total_memory_required:
         if device != torch.device("cpu"):
@@ -441,11 +452,15 @@ def load_models_gpu(models, memory_required=0):
 def load_model_gpu(model):
     return load_models_gpu([model])

-def cleanup_models():
+def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
         if sys.getrefcount(current_loaded_models[i].model) <= 2:
-            to_delete = [i] + to_delete
+            if not keep_clone_weights_loaded:
+                to_delete = [i] + to_delete
+            #TODO: find a less fragile way to do this.
+            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+                to_delete = [i] + to_delete

     for i in to_delete:
         x = current_loaded_models.pop(i)
@@ -602,7 +617,8 @@ def supports_dtype(device, dtype): #TODO
 def device_supports_non_blocking(device):
     if is_device_mps(device):
         return False #pytorch bug? mps doesn't support non blocking
-    return True
+    return False
+    # return True #TODO: figure out why this causes issues

 def cast_to_device(tensor, device, dtype, copy=False):
     device_supports_cast = False
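
free_memory now gathers candidates as (refcount, size, index) tuples and unloads them in sorted order, so the least-referenced (and, on ties, smallest) models go first, stopping as soon as enough memory is free. The ordering in isolation, with hypothetical numbers (not from the diff):

    can_unload = [(3, 2_000_000_000, 0), (2, 500_000_000, 1), (2, 4_000_000_000, 2)]
    print(sorted(can_unload))  # tuples sort field by field: refcount, then size, then index
    # [(2, 500000000, 1), (2, 4000000000, 2), (3, 2000000000, 0)]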

comfy/model_patcher.py

@@ -7,6 +7,38 @@ import uuid
 import comfy.utils
 import comfy.model_management

+def apply_weight_decompose(dora_scale, weight):
+    weight_norm = (
+        weight.transpose(0, 1)
+        .reshape(weight.shape[1], -1)
+        .norm(dim=1, keepdim=True)
+        .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
+        .transpose(0, 1)
+    )
+
+    return weight * (dora_scale / weight_norm)
+
+def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
+    to = model_options["transformer_options"].copy()
+
+    if "patches_replace" not in to:
+        to["patches_replace"] = {}
+    else:
+        to["patches_replace"] = to["patches_replace"].copy()
+
+    if name not in to["patches_replace"]:
+        to["patches_replace"][name] = {}
+    else:
+        to["patches_replace"][name] = to["patches_replace"][name].copy()
+
+    if transformer_index is not None:
+        block = (block_name, number, transformer_index)
+    else:
+        block = (block_name, number)
+    to["patches_replace"][name][block] = patch
+
+    model_options["transformer_options"] = to
+    return model_options
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
         self.size = size
@@ -97,16 +129,7 @@ class ModelPatcher:
         to["patches"][name] = to["patches"].get(name, []) + [patch]

     def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
-        to = self.model_options["transformer_options"]
-        if "patches_replace" not in to:
-            to["patches_replace"] = {}
-        if name not in to["patches_replace"]:
-            to["patches_replace"][name] = {}
-        if transformer_index is not None:
-            block = (block_name, number, transformer_index)
-        else:
-            block = (block_name, number)
-        to["patches_replace"][name][block] = patch
+        self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index)

     def set_model_attn1_patch(self, patch):
         self.set_model_patch(patch, "attn1_patch")
@@ -138,6 +161,15 @@ class ModelPatcher:
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj

+    def get_model_object(self, name):
+        if name in self.object_patches:
+            return self.object_patches[name]
+        else:
+            if name in self.object_patches_backup:
+                return self.object_patches_backup[name]
+            else:
+                return comfy.utils.get_attr(self.model, name)
+
     def model_patches_to(self, device):
         to = self.model_options["transformer_options"]
         if "patches" in to:
@@ -266,7 +298,7 @@ class ModelPatcher:
                 if weight_key in self.patches:
                     m.weight_function = LowVramPatch(weight_key, self)
                 if bias_key in self.patches:
-                    m.bias_function = LowVramPatch(weight_key, self)
+                    m.bias_function = LowVramPatch(bias_key, self)

                 m.prev_comfy_cast_weights = m.comfy_cast_weights
                 m.comfy_cast_weights = True
@@ -309,6 +341,7 @@ class ModelPatcher:
             elif patch_type == "lora": #lora/locon
                 mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                 mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
+                dora_scale = v[4]
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
                 if v[3] is not None:
@@ -318,6 +351,8 @@ class ModelPatcher:
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                 try:
                     weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "lokr":
@@ -328,6 +363,7 @@ class ModelPatcher:
                 w2_a = v[5]
                 w2_b = v[6]
                 t2 = v[7]
+                dora_scale = v[8]
                 dim = None

                 if w1 is None:
@@ -357,6 +393,8 @@ class ModelPatcher:
                 try:
                     weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "loha":
@@ -366,6 +404,7 @@ class ModelPatcher:
                     alpha *= v[2] / w1b.shape[0]
                 w2a = v[3]
                 w2b = v[4]
+                dora_scale = v[7]
                 if v[5] is not None: #cp decomposition
                     t1 = v[5]
                     t2 = v[6]
@@ -386,12 +425,16 @@ class ModelPatcher:
                 try:
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "glora":
                 if v[4] is not None:
                     alpha *= v[4] / v[0].shape[0]

+                dora_scale = v[5]
+
                 a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                 a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                 b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
@@ -399,6 +442,8 @@ class ModelPatcher:
                 try:
                     weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             else:
@@ -437,4 +482,4 @@ class ModelPatcher:
         for k in keys:
             comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])

-        self.object_patches_backup = {}
+        self.object_patches_backup.clear()
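
apply_weight_decompose rescales the patched weight so that each slice along dim 1 ends up with the norm given by the corresponding dora_scale entry. A standalone numeric check with made-up shapes (the function body is copied from the hunk above; everything else is illustration):

    import torch

    def apply_weight_decompose(dora_scale, weight):
        weight_norm = (
            weight.transpose(0, 1)
            .reshape(weight.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
            .transpose(0, 1)
        )
        return weight * (dora_scale / weight_norm)

    w = torch.randn(8, 4)            # hypothetical patched weight
    scale = torch.full((1, 4), 2.0)  # hypothetical dora_scale
    out = apply_weight_decompose(scale, w)
    print(out.norm(dim=0))           # every column norm is now 2.0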

comfy/sample.py

@@ -1,10 +1,9 @@
 import torch
 import comfy.model_management
 import comfy.samplers
-import comfy.conds
 import comfy.utils
-import math
 import numpy as np
+import logging

 def prepare_noise(latent_image, seed, noise_inds=None):
     """
@@ -25,94 +24,21 @@ def prepare_noise(latent_image, seed, noise_inds=None):
     noises = torch.cat(noises, axis=0)
     return noises

-def prepare_mask(noise_mask, shape, device):
-    """ensures noise mask is of proper dimensions"""
-    noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
-    noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
-    noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
-    noise_mask = noise_mask.to(device)
-    return noise_mask
-
-def get_models_from_cond(cond, model_type):
-    models = []
-    for c in cond:
-        if model_type in c:
-            models += [c[model_type]]
-    return models
-
-def convert_cond(cond):
-    out = []
-    for c in cond:
-        temp = c[1].copy()
-        model_conds = temp.get("model_conds", {})
-        if c[0] is not None:
-            model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
-            temp["cross_attn"] = c[0]
-        temp["model_conds"] = model_conds
-        out.append(temp)
-    return out
-
-def get_additional_models(positive, negative, dtype):
-    """loads additional models in positive and negative conditioning"""
-    control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
-
-    inference_memory = 0
-    control_models = []
-    for m in control_nets:
-        control_models += m.get_models()
-        inference_memory += m.inference_memory_requirements(dtype)
-
-    gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
-    gligen = [x[1] for x in gligen]
-    models = control_models + gligen
-    return models, inference_memory
+def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
+    logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
+    return model, positive, negative, noise_mask, []

 def cleanup_additional_models(models):
-    """cleanup additional models that were loaded"""
-    for m in models:
-        if hasattr(m, 'cleanup'):
-            m.cleanup()
-
-def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
-    device = model.load_device
-    positive = convert_cond(positive)
-    negative = convert_cond(negative)
-
-    if noise_mask is not None:
-        noise_mask = prepare_mask(noise_mask, noise_shape, device)
-
-    real_model = None
-    models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
-    comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
-    real_model = model.model
-
-    return real_model, positive, negative, noise_mask, models
+    logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")

 def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
-    real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
-
-    noise = noise.to(model.load_device)
-    latent_image = latent_image.to(model.load_device)
-
-    sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
-
-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
+    sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
+    samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
     samples = samples.to(comfy.model_management.intermediate_device())
-    cleanup_additional_models(models)
-    cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
     return samples

 def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
-    real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
-    noise = noise.to(model.load_device)
-    latent_image = latent_image.to(model.load_device)
-    sigmas = sigmas.to(model.load_device)
-
-    samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
+    samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
     samples = samples.to(comfy.model_management.intermediate_device())
-    cleanup_additional_models(models)
-    cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
     return samples
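
For downstream callers nothing changes at this level: comfy.sample.sample keeps its signature, and device transfer plus model loading now happen inside comfy.samplers. A typical call for context (illustration, assuming `model` is a loaded ModelPatcher and `positive`/`negative` are encoded cond lists from elsewhere):

    import torch
    import comfy.sample

    latent = torch.zeros((1, 4, 64, 64))
    noise = comfy.sample.prepare_noise(latent, seed=42)
    samples = comfy.sample.sample(model, noise, steps=20, cfg=7.0,
                                  sampler_name="euler", scheduler="normal",
                                  positive=positive, negative=negative,
                                  latent_image=latent, seed=42)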

comfy/sampler_helpers.py (new file, 76 lines)

@@ -0,0 +1,76 @@
+import torch
+import comfy.model_management
+import comfy.conds
+
+def prepare_mask(noise_mask, shape, device):
+    """ensures noise mask is of proper dimensions"""
+    noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
+    noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
+    noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
+    noise_mask = noise_mask.to(device)
+    return noise_mask
+
+def get_models_from_cond(cond, model_type):
+    models = []
+    for c in cond:
+        if model_type in c:
+            models += [c[model_type]]
+    return models
+
+def convert_cond(cond):
+    out = []
+    for c in cond:
+        temp = c[1].copy()
+        model_conds = temp.get("model_conds", {})
+        if c[0] is not None:
+            model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
+            temp["cross_attn"] = c[0]
+        temp["model_conds"] = model_conds
+        out.append(temp)
+    return out
+
+def get_additional_models(conds, dtype):
+    """loads additional models in conditioning"""
+    cnets = []
+    gligen = []
+
+    for k in conds:
+        cnets += get_models_from_cond(conds[k], "control")
+        gligen += get_models_from_cond(conds[k], "gligen")
+
+    control_nets = set(cnets)
+
+    inference_memory = 0
+    control_models = []
+    for m in control_nets:
+        control_models += m.get_models()
+        inference_memory += m.inference_memory_requirements(dtype)
+
+    gligen = [x[1] for x in gligen]
+    models = control_models + gligen
+    return models, inference_memory
+
+def cleanup_additional_models(models):
+    """cleanup additional models that were loaded"""
+    for m in models:
+        if hasattr(m, 'cleanup'):
+            m.cleanup()
+
+def prepare_sampling(model, noise_shape, conds):
+    device = model.load_device
+    real_model = None
+    models, inference_memory = get_additional_models(conds, model.model_dtype())
+    comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
+    real_model = model.model
+
+    return real_model, conds, models
+
+def cleanup_models(conds, models):
+    cleanup_additional_models(models)
+
+    control_cleanup = []
+    for k in conds:
+        control_cleanup += get_models_from_cond(conds[k], "control")
+
+    cleanup_additional_models(set(control_cleanup))
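
The helpers moved here are consumed by CFGGuider.sample in the samplers.py hunks below: conds arrive grouped in a dict keyed by name, and prepare_sampling loads the patcher plus any ControlNet/GLIGEN models the conds reference. A sketch of the lifecycle (illustration, assuming `model` is a ModelPatcher and the conds already went through convert_cond):

    import comfy.sampler_helpers

    conds = {"positive": positive, "negative": negative}
    real_model, conds, loaded = comfy.sampler_helpers.prepare_sampling(model, noise.shape, conds)
    try:
        pass  # run the sampler against real_model here
    finally:
        comfy.sampler_helpers.cleanup_models(conds, loaded)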

comfy/samplers.py

@@ -5,6 +5,7 @@ import collections
 from comfy import model_management
 import math
 import logging
+import comfy.sampler_helpers

 def get_area_and_mult(conds, x_in, timestep_in):
     area = (x_in.shape[2], x_in.shape[3], 0, 0)
@@ -127,30 +128,23 @@ def cond_cat(c_list):
     return out

-def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
-    out_cond = torch.zeros_like(x_in)
-    out_count = torch.ones_like(x_in) * 1e-37
-
-    out_uncond = torch.zeros_like(x_in)
-    out_uncond_count = torch.ones_like(x_in) * 1e-37
-
-    COND = 0
-    UNCOND = 1
-
+def calc_cond_batch(model, conds, x_in, timestep, model_options):
+    out_conds = []
+    out_counts = []
     to_run = []
-    for x in cond:
-        p = get_area_and_mult(x, x_in, timestep)
-        if p is None:
-            continue
-
-        to_run += [(p, COND)]
-    if uncond is not None:
-        for x in uncond:
-            p = get_area_and_mult(x, x_in, timestep)
-            if p is None:
-                continue
-
-            to_run += [(p, UNCOND)]
+
+    for i in range(len(conds)):
+        out_conds.append(torch.zeros_like(x_in))
+        out_counts.append(torch.ones_like(x_in) * 1e-37)
+
+        cond = conds[i]
+        if cond is not None:
+            for x in cond:
+                p = get_area_and_mult(x, x_in, timestep)
+                if p is None:
+                    continue
+
+                to_run += [(p, i)]

     while len(to_run) > 0:
         first = to_run[0]
@@ -222,74 +216,66 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
             output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
         else:
             output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
         del input_x

         for o in range(batch_chunks):
-            if cond_or_uncond[o] == COND:
-                out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-            else:
-                out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
+            cond_index = cond_or_uncond[o]
+            out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
+            out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
         del mult

-    out_cond /= out_count
-    del out_count
-    out_uncond /= out_uncond_count
-    del out_uncond_count
+    for i in range(len(out_conds)):
+        out_conds[i] /= out_counts[i]

-    return out_cond, out_uncond
+    return out_conds
+
+def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
+    logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
+    return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
+
+def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
+    if "sampler_cfg_function" in model_options:
+        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
+                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
+        cfg_result = x - model_options["sampler_cfg_function"](args)
+    else:
+        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
+
+    for fn in model_options.get("sampler_post_cfg_function", []):
+        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
+                "sigma": timestep, "model_options": model_options, "input": x}
+        cfg_result = fn(args)
+
+    return cfg_result

 #The main sampling function shared by all the samplers
 #Returns denoised
 def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
     if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
         uncond_ = None
     else:
         uncond_ = uncond

-    cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
-    if "sampler_cfg_function" in model_options:
-        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
-                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
-        cfg_result = x - model_options["sampler_cfg_function"](args)
-    else:
-        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
-
-    for fn in model_options.get("sampler_post_cfg_function", []):
-        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
-                "sigma": timestep, "model_options": model_options, "input": x}
-        cfg_result = fn(args)
-
-    return cfg_result
+    conds = [cond, uncond_]
+    out = calc_cond_batch(model, conds, x, timestep, model_options)
+    return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)

-class CFGNoisePredictor(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.inner_model = model
-    def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
-        out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
-        return out
-    def forward(self, *args, **kwargs):
-        return self.apply_model(*args, **kwargs)
-
-class KSamplerX0Inpaint(torch.nn.Module):
+class KSamplerX0Inpaint:
     def __init__(self, model, sigmas):
-        super().__init__()
         self.inner_model = model
         self.sigmas = sigmas
-    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
+
+    def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
         if denoise_mask is not None:
             if "denoise_mask_function" in model_options:
                 denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
             latent_mask = 1. - denoise_mask
             x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
-        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
+        out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
         if denoise_mask is not None:
             out = out * denoise_mask + self.latent_image * latent_mask
         return out

-def simple_scheduler(model, steps):
-    s = model.model_sampling
+def simple_scheduler(model_sampling, steps):
+    s = model_sampling
     sigs = []
     ss = len(s.sigmas) / steps
     for x in range(steps):
@@ -297,8 +283,8 @@ def simple_scheduler(model, steps):
         sigs += [0.0]
     return torch.FloatTensor(sigs)

-def ddim_scheduler(model, steps):
-    s = model.model_sampling
+def ddim_scheduler(model_sampling, steps):
+    s = model_sampling
     sigs = []
     ss = max(len(s.sigmas) // steps, 1)
     x = 1
@@ -309,8 +295,8 @@ def ddim_scheduler(model, steps):
         sigs += [0.0]
     return torch.FloatTensor(sigs)

-def normal_scheduler(model, steps, sgm=False, floor=False):
-    s = model.model_sampling
+def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
+    s = model_sampling
     start = s.timestep(s.sigma_max)
     end = s.timestep(s.sigma_min)
@ -571,61 +557,122 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return KSAMPLER(sampler_function, extra_options, inpaint_options) return KSAMPLER(sampler_function, extra_options, inpaint_options)
def wrap_model(model):
model_denoise = CFGNoisePredictor(model)
return model_denoise
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
positive = positive[:] for k in conds:
negative = negative[:] conds[k] = conds[k][:]
resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device)
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device) for k in conds:
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device) calculate_start_end_timesteps(model, conds[k])
model_wrap = wrap_model(model)
calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive)
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'): if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) for k in conds:
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
#make sure each cond area has an opposite one with the same area #make sure each cond area has an opposite one with the same area
for c in positive: for k in conds:
create_cond_with_same_area_if_none(negative, c) for c in conds[k]:
for c in negative: for kk in conds:
create_cond_with_same_area_if_none(positive, c) if k != kk:
create_cond_with_same_area_if_none(conds[kk], c)
pre_run_control(model, negative + positive) for k in conds:
pre_run_control(model, conds[k])
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) if "positive" in conds:
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) positive = conds["positive"]
for k in conds:
if k != "positive":
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} return conds
class CFGGuider:
def __init__(self, model_patcher):
self.model_patcher = model_patcher
self.model_options = model_patcher.model_options
self.original_conds = {}
self.cfg = 1.0
def set_conds(self, positive, negative):
self.inner_set_conds({"positive": positive, "negative": negative})
def set_cfg(self, cfg):
self.cfg = cfg
def inner_set_conds(self, conds):
for k in conds:
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
def __call__(self, *args, **kwargs):
return self.predict_noise(*args, **kwargs)
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = self.inner_model.process_latent_in(latent_image)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
extra_args = {"model_options": self.model_options, "seed":seed}
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
if sigmas.shape[-1] == 0:
return latent_image
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
del self.conds
del self.loaded_models
return output
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
cfg_guider = CFGGuider(model)
cfg_guider.set_conds(positive, negative)
cfg_guider.set_cfg(cfg)
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

def calculate_sigmas(model_sampling, scheduler_name, steps):
    if scheduler_name == "karras":
        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
    elif scheduler_name == "exponential":
        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
    elif scheduler_name == "normal":
        sigmas = normal_scheduler(model_sampling, steps)
    elif scheduler_name == "simple":
        sigmas = simple_scheduler(model_sampling, steps)
    elif scheduler_name == "ddim_uniform":
        sigmas = ddim_scheduler(model_sampling, steps)
    elif scheduler_name == "sgm_uniform":
        sigmas = normal_scheduler(model_sampling, steps, sgm=True)
    else:
        logging.error("error invalid scheduler {}".format(scheduler_name))
    return sigmas
@@ -667,7 +714,7 @@ class KSampler:
            steps += 1
            discard_penultimate_sigma = True

        sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps)

        if discard_penultimate_sigma:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
@@ -678,9 +725,12 @@ class KSampler:
        if denoise is None or denoise > 0.9999:
            self.sigmas = self.calculate_sigmas(steps).to(self.device)
        else:
            if denoise <= 0.0:
                self.sigmas = torch.FloatTensor([])
            else:
                new_steps = int(steps/denoise)
                sigmas = self.calculate_sigmas(new_steps).to(self.device)
                self.sigmas = sigmas[-(steps + 1):]

    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
        if sigmas is None:

View File

@@ -214,12 +214,18 @@ class VAE:
                #default SD1.x/SD2.x VAE parameters
                ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}

                if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
                    ddconfig['ch_mult'] = [1, 2, 4]
                    self.downscale_ratio = 4
                    self.upscale_ratio = 4

                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
                if 'quant_conv.weight' in sd:
                    self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
                else:
                    self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
                                                                encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
                                                                decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
            else:
                self.first_stage_model = AutoencoderKL(**(config['params']))
        self.first_stage_model = self.first_stage_model.eval()
@@ -600,7 +606,7 @@ def load_unet(unet_path):
        raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
    return model

def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
    clip_sd = None
    load_models = [model]
    if clip is not None:
@@ -610,4 +616,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
    model_management.load_models_gpu(load_models)
    clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
    sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
    for k in extra_keys:
        sd[k] = extra_keys[k]

    comfy.utils.save_torch_file(sd, output_path, metadata=metadata)

View File

@@ -70,8 +70,8 @@ class SD20(supported_models_base.BASE):
    def model_type(self, state_dict, prefix=""):
        if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
            out = state_dict.get(k, None)
            if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                return model_base.ModelType.V_PREDICTION
        return model_base.ModelType.EPS
@@ -174,6 +174,11 @@ class SDXL(supported_models_base.BASE):
            self.sampling_settings["sigma_max"] = 80.0
            self.sampling_settings["sigma_min"] = 0.002
            return model_base.ModelType.EDM
        elif "edm_vpred.sigma_max" in state_dict:
            self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
            if "edm_vpred.sigma_min" in state_dict:
                self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
            return model_base.ModelType.V_PREDICTION_EDM
        elif "v_pred" in state_dict:
            return model_base.ModelType.V_PREDICTION
        else:
@@ -334,6 +339,11 @@ class Stable_Zero123(supported_models_base.BASE):
        "num_head_channels": -1,
    }

    required_keys = {
        "cc_projection.weight": None,
        "cc_projection.bias": None,
    }

    clip_vision_prefix = "cond_stage_model.model.visual."
    latent_format = latent_formats.SD15
@@ -439,6 +449,33 @@ class Stable_Cascade_B(Stable_Cascade_C):
        out = model_base.StableCascade_B(self, device=device)
        return out
class SD15_instructpix2pix(SD15):
unet_config = {
"context_dim": 768,
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
"in_channels": 8,
}
def get_model(self, state_dict, prefix="", device=None):
return model_base.SD15_instructpix2pix(self, device=device)
class SDXL_instructpix2pix(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 10, 10],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
"in_channels": 8,
}
def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
models += [SVD_img2vid]

View File

@@ -16,6 +16,8 @@ class BASE:
        "num_head_channels": 64,
    }

    required_keys = {}

    clip_prefix = []
    clip_vision_prefix = None
    noise_aug_config = None
@@ -28,10 +30,14 @@ class BASE:
    manual_cast_dtype = None

    @classmethod
    def matches(s, unet_config, state_dict=None):
        for k in s.unet_config:
            if k not in unet_config or s.unet_config[k] != unet_config[k]:
                return False
        if state_dict is not None:
            for k in s.required_keys:
                if k not in state_dict:
                    return False
        return True

    def model_type(self, state_dict, prefix=""):
@@ -41,7 +47,8 @@ class BASE:
        return self.unet_config["in_channels"] > 4

    def __init__(self, unet_config):
        self.unet_config = unet_config.copy()
        self.sampling_settings = self.sampling_settings.copy()
        self.latent_format = self.latent_format()
        for x in self.unet_extra_config:
            self.unet_config[x] = self.unet_extra_config[x]

View File

@@ -0,0 +1,45 @@
#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
import numpy as np
import torch
def loglinear_interp(t_steps, num_steps):
"""
Performs log-linear interpolation of a given array of decreasing numbers.
"""
xs = np.linspace(0, 1, len(t_steps))
ys = np.log(t_steps[::-1])
new_xs = np.linspace(0, 1, num_steps)
new_ys = np.interp(new_xs, xs, ys)
interped_ys = np.exp(new_ys)[::-1].copy()
return interped_ys
NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582],
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
"SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
class AlignYourStepsScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model_type": (["SD1", "SDXL", "SVD"], ),
"steps": ("INT", {"default": 10, "min": 10, "max": 10000}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model_type, steps):
sigmas = NOISE_LEVELS[model_type][:]
if (steps + 1) != len(sigmas):
sigmas = loglinear_interp(sigmas, steps + 1)
sigmas[-1] = 0
return (torch.FloatTensor(sigmas), )
NODE_CLASS_MAPPINGS = {
"AlignYourStepsScheduler": AlignYourStepsScheduler,
}
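Since the NOISE_LEVELS tables hold 11 sigmas (10 steps), any other step count goes through loglinear_interp; a quick sketch of calling the scheduler outside a graph:

# Hypothetical standalone call; "SD1" and 20 mirror the node's inputs.
(sigmas,) = AlignYourStepsScheduler().get_sigmas("SD1", 20)
print(sigmas.shape)  # torch.Size([21]); the last sigma is forced to 0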

View File

@@ -1,4 +1,3 @@
#From https://github.com/kornia/kornia
import math
import torch

View File

@@ -8,7 +8,7 @@ class CLIPTextEncodeSDXLRefiner:
            "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
            "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
            "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
            "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
            }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
@@ -30,8 +30,8 @@ class CLIPTextEncodeSDXL:
            "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
            "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
            "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
            "text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
            "text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
            }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

View File

@@ -3,7 +3,7 @@
class CLIPTextEncodeControlnet:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True, "dynamicPrompts": True})}}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

View File

@@ -4,6 +4,7 @@ from comfy.k_diffusion import sampling as k_diffusion_sampling
import latent_preview
import torch
import comfy.utils
import node_helpers
class BasicScheduler:

@@ -24,10 +25,11 @@ class BasicScheduler:
    def get_sigmas(self, model, scheduler, steps, denoise):
        total_steps = steps
        if denoise < 1.0:
            if denoise <= 0.0:
                return (torch.FloatTensor([]),)
            total_steps = int(steps/denoise)

        sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu()
        sigmas = sigmas[-(steps + 1):]
        return (sigmas, )
@@ -160,6 +162,9 @@ class FlipSigmas:
    FUNCTION = "get_sigmas"

    def get_sigmas(self, sigmas):
        if len(sigmas) == 0:
            return (sigmas,)

        sigmas = sigmas.flip(0)
        if sigmas[0] == 0:
            sigmas[0] = 0.0001
@@ -310,6 +315,24 @@ class SamplerDPMAdaptative:
                                                  "s_noise":s_noise })
        return (sampler, )
class Noise_EmptyNoise:
def __init__(self):
self.seed = 0
def generate_noise(self, input_latent):
latent_image = input_latent["samples"]
return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
class Noise_RandomNoise:
def __init__(self, seed):
self.seed = seed
def generate_noise(self, input_latent):
latent_image = input_latent["samples"]
batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None
return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds)
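Both classes implement the small duck-typed interface a NOISE output is expected to provide: a seed attribute plus generate_noise(input_latent) returning a CPU tensor shaped like the latent. Under that assumption, a custom noise source only needs a few lines:

# Hypothetical custom NOISE object built on the same interface.
class Noise_ScaledRandomNoise(Noise_RandomNoise):
    def __init__(self, seed, scale):
        super().__init__(seed)
        self.scale = scale

    def generate_noise(self, input_latent):
        # Same seeded noise as Noise_RandomNoise, attenuated by a constant factor.
        return super().generate_noise(input_latent) * self.scale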
class SamplerCustom:
    @classmethod
    def INPUT_TYPES(s):
@@ -337,10 +360,9 @@ class SamplerCustom:
        latent = latent_image
        latent_image = latent["samples"]
        if not add_noise:
            noise = Noise_EmptyNoise().generate_noise(latent)
        else:
            noise = Noise_RandomNoise(noise_seed).generate_noise(latent)

        noise_mask = None
        if "noise_mask" in latent:
@@ -361,6 +383,207 @@ class SamplerCustom:
            out_denoised = out
        return (out, out_denoised)
class Guider_Basic(comfy.samplers.CFGGuider):
def set_conds(self, positive):
self.inner_set_conds({"positive": positive})
class BasicGuider:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"conditioning": ("CONDITIONING", ),
}
}
RETURN_TYPES = ("GUIDER",)
FUNCTION = "get_guider"
CATEGORY = "sampling/custom_sampling/guiders"
def get_guider(self, model, conditioning):
guider = Guider_Basic(model)
guider.set_conds(conditioning)
return (guider,)
class CFGGuider:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
}
}
RETURN_TYPES = ("GUIDER",)
FUNCTION = "get_guider"
CATEGORY = "sampling/custom_sampling/guiders"
def get_guider(self, model, positive, negative, cfg):
guider = comfy.samplers.CFGGuider(model)
guider.set_conds(positive, negative)
guider.set_cfg(cfg)
return (guider,)
class Guider_DualCFG(comfy.samplers.CFGGuider):
def set_cfg(self, cfg1, cfg2):
self.cfg1 = cfg1
self.cfg2 = cfg2
def set_conds(self, positive, middle, negative):
middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative})
def predict_noise(self, x, timestep, model_options={}, seed=None):
negative_cond = self.conds.get("negative", None)
middle_cond = self.conds.get("middle", None)
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options)
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
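In effect (leaving aside any sampler_cfg_function hooks applied inside cfg_function), the two guidance scales combine as:

# Rough shape of Guider_DualCFG.predict_noise, ignoring model_options hooks:
#   pred = uncond + cfg2 * (middle - uncond) + cfg1 * (positive - middle)
# where uncond, middle and positive are the three calc_cond_batch outputs above.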
class DualCFGGuider:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"cond1": ("CONDITIONING", ),
"cond2": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
}
}
RETURN_TYPES = ("GUIDER",)
FUNCTION = "get_guider"
CATEGORY = "sampling/custom_sampling/guiders"
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative):
guider = Guider_DualCFG(model)
guider.set_conds(cond1, cond2, negative)
guider.set_cfg(cfg_conds, cfg_cond2_negative)
return (guider,)
class DisableNoise:
@classmethod
def INPUT_TYPES(s):
return {"required":{
}
}
RETURN_TYPES = ("NOISE",)
FUNCTION = "get_noise"
CATEGORY = "sampling/custom_sampling/noise"
def get_noise(self):
return (Noise_EmptyNoise(),)
class RandomNoise(DisableNoise):
@classmethod
def INPUT_TYPES(s):
return {"required":{
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
}
}
def get_noise(self, noise_seed):
return (Noise_RandomNoise(noise_seed),)
class SamplerCustomAdvanced:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"noise": ("NOISE", ),
"guider": ("GUIDER", ),
"sampler": ("SAMPLER", ),
"sigmas": ("SIGMAS", ),
"latent_image": ("LATENT", ),
}
}
RETURN_TYPES = ("LATENT","LATENT")
RETURN_NAMES = ("output", "denoised_output")
FUNCTION = "sample"
CATEGORY = "sampling/custom_sampling"
def sample(self, noise, guider, sampler, sigmas, latent_image):
latent = latent_image
latent_image = latent["samples"]
noise_mask = None
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
x0_output = {}
callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
samples = samples.to(comfy.model_management.intermediate_device())
out = latent.copy()
out["samples"] = samples
if "x0" in x0_output:
out_denoised = latent.copy()
out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
else:
out_denoised = out
return (out, out_denoised)
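Taken together, these nodes decompose SamplerCustom into swappable parts. A sketch of the equivalent wiring, calling the node classes directly (model, positive, negative and a latent dict are assumed to come from the usual loaders; comfy.samplers.sampler_object is the helper KSamplerSelect uses):

# Hedged sketch; in practice these objects are wired up as graph nodes.
noise = RandomNoise().get_noise(noise_seed=42)[0]
guider = CFGGuider().get_guider(model, positive, negative, cfg=8.0)[0]
sampler = comfy.samplers.sampler_object("euler")
(sigmas,) = BasicScheduler().get_sigmas(model, "normal", steps=20, denoise=1.0)
out, out_denoised = SamplerCustomAdvanced().sample(noise, guider, sampler, sigmas, latent)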
class AddNoise:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"noise": ("NOISE", ),
"sigmas": ("SIGMAS", ),
"latent_image": ("LATENT", ),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "add_noise"
CATEGORY = "_for_testing/custom_sampling/noise"
def add_noise(self, model, noise, sigmas, latent_image):
if len(sigmas) == 0:
return latent_image
latent = latent_image
latent_image = latent["samples"]
noisy = noise.generate_noise(latent)
model_sampling = model.get_model_object("model_sampling")
process_latent_out = model.get_model_object("process_latent_out")
process_latent_in = model.get_model_object("process_latent_in")
if len(sigmas) > 1:
scale = torch.abs(sigmas[0] - sigmas[-1])
else:
scale = sigmas[0]
if torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = process_latent_in(latent_image)
noisy = model_sampling.noise_scaling(scale, noisy, latent_image)
noisy = process_latent_out(noisy)
noisy = torch.nan_to_num(noisy, nan=0.0, posinf=0.0, neginf=0.0)
out = latent.copy()
out["samples"] = noisy
return (out,)
NODE_CLASS_MAPPINGS = {
    "SamplerCustom": SamplerCustom,
    "BasicScheduler": BasicScheduler,
@@ -378,4 +601,12 @@ NODE_CLASS_MAPPINGS = {
    "SamplerDPMAdaptative": SamplerDPMAdaptative,
    "SplitSigmas": SplitSigmas,
    "FlipSigmas": FlipSigmas,
    "CFGGuider": CFGGuider,
    "DualCFGGuider": DualCFGGuider,
    "BasicGuider": BasicGuider,
    "RandomNoise": RandomNoise,
    "DisableNoise": DisableNoise,
    "AddNoise": AddNoise,
    "SamplerCustomAdvanced": SamplerCustomAdvanced,
}

View File

@@ -0,0 +1,45 @@
import torch
class InstructPixToPixConditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"pixels": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/instructpix2pix"
def encode(self, positive, negative, pixels, vae):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
concat_latent = vae.encode(pixels)
out_latent = {}
out_latent["samples"] = torch.zeros_like(concat_latent)
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
d["concat_latent_image"] = concat_latent
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1], out_latent)
NODE_CLASS_MAPPINGS = {
"InstructPixToPixConditioning": InstructPixToPixConditioning,
}
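The node mirrors InpaintModelConditioning: the VAE-encoded source image is attached to both conditionings as concat_latent_image (filling the extra four input channels of the 8-channel instructpix2pix UNets registered earlier), while the returned latent starts from zeros:

# Hedged usage sketch; `vae` and `pixels` come from VAELoader / LoadImage.
positive, negative, latent = InstructPixToPixConditioning().encode(positive, negative, pixels, vae)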

View File

@@ -2,7 +2,9 @@ import comfy.sd
import comfy.utils
import comfy.model_base
import comfy.model_management
import comfy.model_sampling
import torch

import folder_paths
import json
import os
@@ -189,6 +191,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
    # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
    # "v2-inpainting"

    extra_keys = {}
    model_sampling = model.get_model_object("model_sampling")
    if isinstance(model_sampling, comfy.model_sampling.ModelSamplingContinuousEDM):
        if isinstance(model_sampling, comfy.model_sampling.V_PREDICTION):
            extra_keys["edm_vpred.sigma_max"] = torch.tensor(model_sampling.sigma_max).float()
            extra_keys["edm_vpred.sigma_min"] = torch.tensor(model_sampling.sigma_min).float()

    if model.model.model_type == comfy.model_base.ModelType.EPS:
        metadata["modelspec.predict_key"] = "epsilon"
    elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
@@ -203,7 +212,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
    output_checkpoint = f"{filename}_{counter:05}_.safetensors"
    output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

    comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)

class CheckpointSave:
    def __init__(self):

View File

@@ -0,0 +1,60 @@
import comfy_extras.nodes_model_merging
class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["time_embed."] = argument
arg_dict["label_emb."] = argument
for i in range(12):
arg_dict["input_blocks.{}.".format(i)] = argument
for i in range(3):
arg_dict["middle_block.{}.".format(i)] = argument
for i in range(12):
arg_dict["output_blocks.{}.".format(i)] = argument
arg_dict["out."] = argument
return {"required": arg_dict}
class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["time_embed."] = argument
arg_dict["label_emb."] = argument
for i in range(9):
arg_dict["input_blocks.{}".format(i)] = argument
for i in range(3):
arg_dict["middle_block.{}".format(i)] = argument
for i in range(9):
arg_dict["output_blocks.{}".format(i)] = argument
arg_dict["out."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
"ModelMergeSDXL": ModelMergeSDXL,
}

comfy_extras/nodes_pag.py Normal file (56 lines)
View File

@@ -0,0 +1,56 @@
#Modified/simplified version of the node from: https://github.com/pamparamm/sd-perturbed-attention
#If you want the one with more options see the above repo.
#My modified one here is more basic but is less likely to break with ComfyUI updates.
import comfy.model_patcher
import comfy.samplers
class PerturbedAttentionGuidance:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, scale):
unet_block = "middle"
unet_block_id = 0
m = model.clone()
def perturbed_attention(q, k, v, extra_options, mask=None):
return v
def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
model_options = args["model_options"].copy()
x = args["input"]
if scale == 0:
return cfg_result
# Replace Self-attention with PAG
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, perturbed_attention, "attn1", unet_block, unet_block_id)
(pag,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
return cfg_result + (cond_pred - pag) * scale
m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m,)
NODE_CLASS_MAPPINGS = {
"PerturbedAttentionGuidance": PerturbedAttentionGuidance,
}
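As the comments note, this is a reduced PAG: the patch replaces the middle-block self-attention output with plain v for one extra model evaluation and pushes the CFG result away from that degraded prediction. A sketch of applying it outside a graph, assuming model is a loaded ModelPatcher:

# Hedged direct call; normally this runs as the PerturbedAttentionGuidance node.
(patched_model,) = PerturbedAttentionGuidance().patch(model, scale=3.0)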

View File

@@ -1,10 +1,20 @@
import torch
import comfy.model_management
import comfy.sampler_helpers
import comfy.samplers
import comfy.utils
import node_helpers
def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
pos = noise_pred_pos - noise_pred_nocond
neg = noise_pred_neg - noise_pred_nocond
perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos
perp_neg = perp * neg_scale
cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
return cfg_result
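perp_neg subtracts from the negative direction its projection onto the positive one, so only the component of the negative prompt orthogonal to the positive guidance pushes the result away. A quick standalone check of that projection:

# Hypothetical numeric check that the projected component is orthogonal to pos.
import torch

pos = torch.randn(16)
neg = torch.randn(16)
perp = neg - ((torch.mul(neg, pos).sum()) / (torch.norm(pos) ** 2)) * pos
print(torch.dot(perp, pos))  # ~0 up to floating point error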
#TODO: This node should be removed, it has been replaced with PerpNegGuider
class PerpNeg:
    @classmethod
    def INPUT_TYPES(s):
@@ -19,7 +29,7 @@ class PerpNeg:
    def patch(self, model, empty_conditioning, neg_scale):
        m = model.clone()
        nocond = comfy.sampler_helpers.convert_cond(empty_conditioning)

        def cfg_function(args):
            model = args["model"]
@@ -31,14 +41,9 @@ class PerpNeg:
            model_options = args["model_options"]
            nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")

            (noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options)

            cfg_result = x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale)
            return cfg_result

        m.set_model_sampler_cfg_function(cfg_function)
@@ -46,10 +51,52 @@ class PerpNeg:
        return (m, )
class Guider_PerpNeg(comfy.samplers.CFGGuider):
def set_conds(self, positive, negative, empty_negative_prompt):
empty_negative_prompt = node_helpers.conditioning_set_values(empty_negative_prompt, {"prompt_type": "negative"})
self.inner_set_conds({"positive": positive, "empty_negative_prompt": empty_negative_prompt, "negative": negative})
def set_cfg(self, cfg, neg_scale):
self.cfg = cfg
self.neg_scale = neg_scale
def predict_noise(self, x, timestep, model_options={}, seed=None):
positive_cond = self.conds.get("positive", None)
negative_cond = self.conds.get("negative", None)
empty_cond = self.conds.get("empty_negative_prompt", None)
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, positive_cond, empty_cond], x, timestep, model_options)
return perp_neg(x, out[1], out[0], out[2], self.neg_scale, self.cfg)
class PerpNegGuider:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"empty_conditioning": ("CONDITIONING", ),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}
}
RETURN_TYPES = ("GUIDER",)
FUNCTION = "get_guider"
CATEGORY = "_for_testing"
def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale):
guider = Guider_PerpNeg(model)
guider.set_conds(positive, negative, empty_conditioning)
guider.set_cfg(cfg, neg_scale)
return (guider,)
NODE_CLASS_MAPPINGS = {
    "PerpNeg": PerpNeg,
    "PerpNegGuider": PerpNegGuider,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PerpNeg": "Perp-Neg (DEPRECATED by PerpNegGuider)",
}

View File

@@ -141,7 +141,7 @@ class PhotoMakerEncode:
        return {"required": { "photomaker": ("PHOTOMAKER",),
                              "image": ("IMAGE",),
                              "clip": ("CLIP", ),
                              "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}),
                             }}

    RETURN_TYPES = ("CONDITIONING",)

View File

@@ -5,6 +5,7 @@ from PIL import Image
import math

import comfy.utils
import comfy.model_management

class Blend:
@@ -102,6 +103,7 @@ class Blur:
        if blur_radius == 0:
            return (image,)

        image = image.to(comfy.model_management.get_torch_device())
        batch_size, height, width, channels = image.shape

        kernel_size = blur_radius * 2 + 1
@@ -112,7 +114,7 @@ class Blur:
        blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
        blurred = blurred.permute(0, 2, 3, 1)

        return (blurred.to(comfy.model_management.intermediate_device()),)

class Quantize:
    def __init__(self):
@@ -204,13 +206,13 @@ class Sharpen:
                    "default": 1.0,
                    "min": 0.1,
                    "max": 10.0,
                    "step": 0.01
                }),
                "alpha": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.0,
                    "max": 5.0,
                    "step": 0.01
                }),
            },
        }
@@ -225,6 +227,7 @@ class Sharpen:
            return (image,)

        batch_size, height, width, channels = image.shape
        image = image.to(comfy.model_management.get_torch_device())

        kernel_size = sharpen_radius * 2 + 1
        kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
@@ -239,7 +242,7 @@ class Sharpen:
        result = torch.clamp(sharpened, 0, 1)

        return (result.to(comfy.model_management.intermediate_device()),)

class ImageScaleToTotalPixels:
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

View File

@@ -150,7 +150,7 @@ class SelfAttentionGuidance:
            degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
            degraded_noised = degraded + x - uncond_pred
            # call into the UNet
            (sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
            return cfg_result + (degraded - sag) * sag_scale

        m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
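This is the same API migration applied to PerpNeg above: calc_cond_uncond_batch(model, cond, uncond, x, sigma, model_options) becomes calc_cond_batch(model, conds, x, sigma, model_options), which takes a list of conds and returns one prediction per entry, so callers can batch any number of conditionings in a single call:

# The old two-cond call expressed through the new list-based API.
cond_pred, uncond_pred = comfy.samplers.calc_cond_batch(model, [cond, uncond], x, sigma, model_options)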

View File

@@ -47,7 +47,7 @@ blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeFor
             "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
             "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
             "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
             "GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60"
}

def cuda_malloc_supported():

View File

@@ -393,7 +393,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
        for name, inputs in input_data_all.items():
            input_data_formatted[name] = [format_value(x) for x in inputs]

        logging.error(f"!!! Exception during processing !!! {ex}")
        logging.error(traceback.format_exc())

        error_details = {
@@ -473,6 +473,7 @@ class PromptExecutor:
        current_outputs = self.caches.outputs.all_node_ids()

        comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
        self.add_message("execution_cached",
                         { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                         broadcast=False)

View File

@@ -1,7 +1,8 @@
import os
import time
import logging

supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl'])

folder_names_and_paths = {}
@@ -44,7 +45,7 @@ if not os.path.exists(input_directory):
    try:
        os.makedirs(input_directory)
    except:
        logging.error("Failed to create input directory")

def set_output_directory(output_dir):
    global output_directory
@@ -146,8 +147,9 @@ def recursive_search(directory, excluded_dir_names=None):
    try:
        dirs[directory] = os.path.getmtime(directory)
    except FileNotFoundError:
        logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")

    logging.debug("recursive file list on directory {}".format(directory))
    for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
        subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
        for file_name in filenames:
@@ -159,8 +161,9 @@ def recursive_search(directory, excluded_dir_names=None):
            try:
                dirs[path] = os.path.getmtime(path)
            except FileNotFoundError:
                logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
                continue

    logging.debug("found {} files".format(len(result)))
    return result, dirs

def filter_files_extensions(files, extensions):
@@ -178,6 +181,8 @@ def get_full_path(folder_name, filename):
        full_path = os.path.join(x, filename)
        if os.path.isfile(full_path):
            return full_path
        elif os.path.islink(full_path):
            logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path))

    return None
@@ -249,7 +254,7 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
              "\n full_output_folder: " + os.path.abspath(full_output_folder) + \
              "\n output_dir: " + output_dir + \
              "\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
        logging.error(err)
        raise Exception(err)

    try:

node_helpers.py Normal file (10 lines)
View File

@@ -0,0 +1,10 @@
def conditioning_set_values(conditioning, values={}):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
n[1][k] = values[k]
c.append(n)
return c
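This helper replaces the copy-and-set loops that were repeated across nodes.py (see the ConditioningSetArea family below): each (cond, options) pair keeps its tensor and gets a shallow-copied options dict with the given values applied.

# Hypothetical example: tag every cond in a list with an area and strength.
c = conditioning_set_values(conditioning, {"area": (64, 64, 0, 0), "strength": 0.8})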

View File

@@ -34,6 +34,7 @@ import importlib
import folder_paths
import latent_preview
import node_helpers

def before_node_execution():
    comfy.model_management.throw_exception_if_processing_interrupted()
@@ -41,12 +42,12 @@ def before_node_execution():
def interrupt_processing(value=True):
    comfy.model_management.interrupt_current_processing(value)

MAX_RESOLUTION=16384

class CLIPTextEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", )}}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
@@ -151,13 +152,9 @@ class ConditioningSetArea:
    CATEGORY = "conditioning"

    def append(self, conditioning, width, height, x, y, strength):
        c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8),
                                                                "strength": strength,
                                                                "set_area_to_bounds": False})
        return (c, )
class ConditioningSetAreaPercentage:

@@ -176,13 +173,9 @@ class ConditioningSetAreaPercentage:
    CATEGORY = "conditioning"

    def append(self, conditioning, width, height, x, y, strength):
        c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x),
                                                                "strength": strength,
                                                                "set_area_to_bounds": False})
        return (c, )
class ConditioningSetAreaStrength:

@@ -197,11 +190,7 @@ class ConditioningSetAreaStrength:
    CATEGORY = "conditioning"

    def append(self, conditioning, strength):
        c = node_helpers.conditioning_set_values(conditioning, {"strength": strength})
        return (c, )
@@ -219,19 +208,15 @@ class ConditioningSetMask:
    CATEGORY = "conditioning"

    def append(self, conditioning, mask, set_cond_area, strength):
        set_area_to_bounds = False
        if set_cond_area != "default":
            set_area_to_bounds = True
        if len(mask.shape) < 3:
            mask = mask.unsqueeze(0)

        c = node_helpers.conditioning_set_values(conditioning, {"mask": mask,
                                                                "set_area_to_bounds": set_area_to_bounds,
                                                                "mask_strength": strength})
        return (c, )
class ConditioningZeroOut:

@@ -266,13 +251,8 @@ class ConditioningSetTimestepRange:
    CATEGORY = "advanced/conditioning"

    def set_range(self, conditioning, start, end):
        c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start,
                                                                "end_percent": end})
        return (c, )
class VAEDecode:

@@ -413,13 +393,8 @@ class InpaintModelConditioning:
        out = []
        for conditioning in [positive, negative]:
            c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
                                                                    "concat_mask": mask})
            out.append(c)
        return (out[0], out[1], out_latent)
@@ -991,7 +966,7 @@ class GLIGENTextBoxApply:
        return {"required": {"conditioning_to": ("CONDITIONING", ),
                             "clip": ("CLIP", ),
                             "gligen_textbox_model": ("GLIGEN", ),
                             "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                             "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
                             "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
                             "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
@@ -1876,6 +1851,7 @@ def load_custom_node(module_path, ignore=set()):
    sp = os.path.splitext(module_path)
    module_name = sp[0]
    try:
        logging.debug("Trying to load custom node {}".format(module_path))
        if os.path.isfile(module_path):
            module_spec = importlib.util.spec_from_file_location(module_name, module_path)
            module_dir = os.path.split(module_path)[0]
@@ -1964,6 +1940,10 @@ def init_custom_nodes():
        "nodes_morphology.py",
        "nodes_stable_cascade.py",
        "nodes_differential_diffusion.py",
        "nodes_ip2p.py",
        "nodes_model_merging_model_specific.py",
        "nodes_pag.py",
        "nodes_align_your_steps.py",
    ]

    import_failed = []

View File

@@ -949,7 +949,7 @@ describe("group node", () => {
      expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
      expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
      expect(p2.widgets.value.widget.options?.max).toBe(16384); // width/height max
      expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10

      expect(p1.widgets.value.value).toBe(128);

View File

@@ -204,13 +204,17 @@ export class EzWidget {
  convertToWidget() {
    if (!this.isConvertedToInput)
      throw new Error(`Widget ${this.widget.name} cannot be converted as it is already a widget.`);
    var menu = this.node.menu["Convert Input to Widget"].item.submenu.options;
    var index = menu.findIndex(a => a.content == `Convert ${this.widget.name} to widget`);
    menu[index].callback.call();
  }

  convertToInput() {
    if (this.isConvertedToInput)
      throw new Error(`Widget ${this.widget.name} cannot be converted as it is already an input.`);
    var menu = this.node.menu["Convert Widget to Input"].item.submenu.options;
    var index = menu.findIndex(a => a.content == `Convert ${this.widget.name} to input`);
    menu[index].callback.call();
  }
}

View File

@@ -20,6 +20,10 @@ const colorPalettes = {
        "MODEL": "#B39DDB", // light lavender-purple
        "STYLE_MODEL": "#C2FFAE", // light green-yellow
        "VAE": "#FF6E6E", // bright red
        "NOISE": "#B0B0B0", // gray
        "GUIDER": "#66FFFF", // cyan
        "SAMPLER": "#ECB4B4", // very soft red
        "SIGMAS": "#CDFFCD", // soft lime green
        "TAESD": "#DCC274", // cheesecake
    },
    "litegraph_base": {

View File

@@ -17,7 +17,7 @@ app.registerExtension({
        // Locate dynamic prompt text widgets
        // Include any widgets with dynamicPrompts set to true, and customtext
        const widgets = node.widgets.filter(
            (n) => n.dynamicPrompts
        );
        for (const widget of widgets) {
            // Override the serialization of the value to resolve dynamic prompts for all widgets supporting it in this node

View File

@@ -256,8 +256,18 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi
    return { customConfig };
}

let useConversionSubmenusSetting;
app.registerExtension({
    name: "Comfy.WidgetInputs",
    init() {
        useConversionSubmenusSetting = app.ui.settings.addSetting({
            id: "Comfy.NodeInputConversionSubmenus",
            name: "Node widget/input conversion sub-menus",
            tooltip: "In the node context menu, place the entries that convert between input/widget in sub-menus.",
            type: "boolean",
            defaultValue: true,
        });
    },
    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        // Add menu options to convert to/from widgets
        const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
@@ -292,12 +302,31 @@ app.registerExtension({
                    }
                }
            }

            //Convert.. main menu
            if (toInput.length) {
                if (useConversionSubmenusSetting.value) {
                    options.push({
                        content: "Convert Widget to Input",
                        submenu: {
                            options: toInput,
                        },
                    });
                } else {
                    options.push(...toInput, null);
                }
            }
            if (toWidget.length) {
                if (useConversionSubmenusSetting.value) {
                    options.push({
                        content: "Convert Input to Widget",
                        submenu: {
                            options: toWidget,
                        },
                    });
                } else {
                    options.push(...toWidget, null);
                }
            }
        }
    }

View File

@@ -7247,7 +7247,7 @@ LGraphNode.prototype.executeAction = function(action)
        //create links
        for (var i = 0; i < clipboard_info.links.length; ++i) {
            var link_info = clipboard_info.links[i];
            var origin_node = undefined;
            var origin_node_relative_id = link_info[0];
            if (origin_node_relative_id != null) {
                origin_node = nodes[origin_node_relative_id];

View File

@@ -170,9 +170,12 @@ export async function importA1111(graph, parameters) {
    const opts = parameters
        .substr(p)
        .split("\n")[1]
        .match(new RegExp("\\s*([^:]+:\\s*([^\"\\{].*?|\".*?\"|\\{.*?\\}))\\s*(,|$)", "g"))
        .reduce((p, n) => {
            const s = n.split(":");
            if (s[1].endsWith(',')) {
                s[1] = s[1].substr(0, s[1].length - 1);
            }
            p[s[0].trim().toLowerCase()] = s[1].trim();
            return p;
        }, {});
@@ -191,6 +194,7 @@ export async function importA1111(graph, parameters) {
    const vaeLoaderNode = LiteGraph.createNode("VAELoader");
    const saveNode = LiteGraph.createNode("SaveImage");
    let hrSamplerNode = null;
    let hrSteps = null;

    const ceil64 = (v) => Math.ceil(v / 64) * 64;
@@ -290,6 +294,9 @@ export async function importA1111(graph, parameters) {
        model(v) {
            setWidgetValue(ckptNode, "ckpt_name", v, true);
        },
        "vae"(v) {
            setWidgetValue(vaeLoaderNode, "vae_name", v, true);
        },
        "cfg scale"(v) {
            setWidgetValue(samplerNode, "cfg", +v);
        },
@@ -316,6 +323,7 @@ export async function importA1111(graph, parameters) {
            const h = ceil64(+wxh[1]);
            const hrUp = popOpt("hires upscale");
            const hrSz = popOpt("hires resize");
            hrSteps = popOpt("hires steps");
            let hrMethod = popOpt("hires upscaler");

            setWidgetValue(imageNode, "width", w);
@@ -398,7 +406,7 @@ export async function importA1111(graph, parameters) {
    }

    if (hrSamplerNode) {
        setWidgetValue(hrSamplerNode, "steps", hrSteps ? +hrSteps : getWidget(samplerNode, "steps").value);
        setWidgetValue(hrSamplerNode, "cfg", getWidget(samplerNode, "cfg").value);
        setWidgetValue(hrSamplerNode, "scheduler", getWidget(samplerNode, "scheduler").value);
        setWidgetValue(hrSamplerNode, "sampler_name", getWidget(samplerNode, "sampler_name").value);
@@ -415,7 +423,7 @@ export async function importA1111(graph, parameters) {
    graph.arrange();

    for (const opt of ["model hash", "ensd", "version", "vae hash", "ti hashes", "lora hashes", "hashes"]) {
        delete opts[opt];
    }

View File

@@ -90,12 +90,15 @@ function dragElement(dragEl, settings) {
    }).observe(dragEl);

    function ensureInBounds() {
        try {
            newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft));
            newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop));

            positionElement();
        }
        catch(exception){
            // robust
        }
    }

    function positionElement() {

View File

@@ -307,7 +307,9 @@ export const ComfyWidgets = {
        return { widget: node.addWidget(widgetType, inputName, val,
            function (v) {
                if (config.round) {
                    this.value = Math.round((v + Number.EPSILON)/config.round)*config.round;
                    if (this.value > config.max) this.value = config.max;
                    if (this.value < config.min) this.value = config.min;
                } else {
                    this.value = v;
                }