diff --git a/comfy/controlnet.py b/comfy/controlnet.py index a6a5fadf6..94fd0a805 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -139,11 +139,12 @@ class ControlBase: return out class ControlNet(ControlBase): - def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): + def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): super().__init__(device) self.control_model = control_model self.load_device = load_device - self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device()) + if control_model is not None: + self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device()) self.global_average_pooling = global_average_pooling self.model_sampling_current = None self.manual_cast_dtype = manual_cast_dtype @@ -184,7 +185,9 @@ class ControlNet(ControlBase): return self.control_merge(None, control, control_prev, output_dtype) def copy(self): - c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c.control_model = self.control_model + c.control_model_wrapped = self.control_model_wrapped self.copy_to(c) return c diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index 08018c54d..ed2a45fea 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys())) # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp code2idx = {"q": 0, "k": 1, "v": 2} +# This function exists because at the time of writing torch.cat can't do fp8 with cuda +def cat_tensors(tensors): + x = 0 + for t in tensors: + x += t.shape[0] + + shape = [x] + list(tensors[0].shape)[1:] + out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype) + + x = 0 + for t in tensors: + out[x:x + t.shape[0]] = t + x += t.shape[0] + + return out def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): new_state_dict = {} @@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): if None in tensors: raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) - new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) + new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors) for k_pre, tensors in capture_qkv_bias.items(): if None in tensors: raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) - new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) + new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors) return new_state_dict diff --git a/comfy/model_base.py b/comfy/model_base.py index d0e8193a2..48990c7e0 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -10,6 +10,7 @@ from .ldm.cascade.stage_c import StageC from .ldm.cascade.stage_b import StageB from enum import Enum from . import utils +from . import latent_formats class ModelType(Enum): EPS = 1 @@ -66,7 +67,8 @@ class BaseModel(torch.nn.Module): self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 - self.inpaint_model = False + + self.concat_keys = () logging.info("model_type {}".format(model_type.name)) logging.debug("adm {}".format(self.adm_channels)) @@ -107,8 +109,7 @@ class BaseModel(torch.nn.Module): def extra_conds(self, **kwargs): out = {} - if self.inpaint_model: - concat_keys = ("mask", "masked_image") + if len(self.concat_keys) > 0: cond_concat = [] denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) concat_latent_image = kwargs.get("concat_latent_image", None) @@ -125,24 +126,16 @@ class BaseModel(torch.nn.Module): concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) - if len(denoise_mask.shape) == len(noise.shape): - denoise_mask = denoise_mask[:,:1] + if denoise_mask is not None: + if len(denoise_mask.shape) == len(noise.shape): + denoise_mask = denoise_mask[:,:1] - denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) - if denoise_mask.shape[-2:] != noise.shape[-2:]: - denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") - denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) + denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) + if denoise_mask.shape[-2:] != noise.shape[-2:]: + denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") + denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) - def blank_inpaint_image_like(latent_image): - blank_image = torch.ones_like(latent_image) - # these are the values for "zero" in pixel space translated to latent space - blank_image[:,0] *= 0.8223 - blank_image[:,1] *= -0.6876 - blank_image[:,2] *= 0.6364 - blank_image[:,3] *= 0.1380 - return blank_image - - for ck in concat_keys: + for ck in self.concat_keys: if denoise_mask is not None: if ck == "mask": cond_concat.append(denoise_mask.to(device)) @@ -152,7 +145,7 @@ class BaseModel(torch.nn.Module): if ck == "mask": cond_concat.append(torch.ones_like(noise)[:,:1]) elif ck == "masked_image": - cond_concat.append(blank_inpaint_image_like(noise)) + cond_concat.append(self.blank_inpaint_image_like(noise)) data = torch.cat(cond_concat, dim=1) out['c_concat'] = conds.CONDNoiseShape(data) adm = self.encode_adm(**kwargs) @@ -220,7 +213,16 @@ class BaseModel(torch.nn.Module): return unet_state_dict def set_inpaint(self): - self.inpaint_model = True + self.concat_keys = ("mask", "masked_image") + def blank_inpaint_image_like(latent_image): + blank_image = torch.ones_like(latent_image) + # these are the values for "zero" in pixel space translated to latent space + blank_image[:,0] *= 0.8223 + blank_image[:,1] *= -0.6876 + blank_image[:,2] *= 0.6364 + blank_image[:,3] *= 0.1380 + return blank_image + self.blank_inpaint_image_like = blank_inpaint_image_like def memory_required(self, input_shape): if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention(): @@ -471,6 +473,42 @@ class SD_X4Upscaler(BaseModel): out['y'] = conds.CONDRegular(noise_level) return out +class IP2P: + def extra_conds(self, **kwargs): + out = {} + + image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + device = kwargs["device"] + + if image is None: + image = torch.zeros_like(noise) + + if image.shape[1:] != noise.shape[1:]: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + + image = utils.resize_to_batch_size(image, noise.shape[0]) + + out['c_concat'] = conds.CONDNoiseShape(self.process_ip2p_image_in(image)) + adm = self.encode_adm(**kwargs) + if adm is not None: + out['y'] = conds.CONDRegular(adm) + return out + +class SD15_instructpix2pix(IP2P, BaseModel): + def __init__(self, model_config, model_type=ModelType.EPS, device=None): + super().__init__(model_config, model_type, device=device) + self.process_ip2p_image_in = lambda image: image + +class SDXL_instructpix2pix(IP2P, SDXL): + def __init__(self, model_config, model_type=ModelType.EPS, device=None): + super().__init__(model_config, model_type, device=device) + if model_type == ModelType.V_PREDICTION_EDM: + self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image) #cosxl ip2p + else: + self.process_ip2p_image_in = lambda image: image #diffusers ip2p + + class StableCascade_C(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): super().__init__(model_config, model_type, device=device, unet_model=StageC) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 65fd41abd..940889bc3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix): return unet_config -def model_config_from_unet_config(unet_config): +def model_config_from_unet_config(unet_config, state_dict=None): for model_config in supported_models.models: - if model_config.matches(unet_config): + if model_config.matches(unet_config, state_dict): return model_config(unet_config) logging.error("no match {}".format(unet_config)) @@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config): def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): unet_config = detect_unet_config(state_dict, unet_key_prefix) - model_config = model_config_from_unet_config(unet_config) + model_config = model_config_from_unet_config(unet_config, state_dict) if model_config is None and use_base_if_no_match: return supported_models_base.BASE(unet_config) else: @@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], 'use_temporal_attention': False, 'use_temporal_resblock': False} + SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], + 'use_temporal_attention': False, 'use_temporal_resblock': False} + SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4], @@ -351,7 +357,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1], 'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS] + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SDXL_diffusers_ip2p] for unet_config in supported_models: matches = True diff --git a/comfy/model_management.py b/comfy/model_management.py index 94849a7b7..57c8e3d06 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -396,6 +396,7 @@ def load_models_gpu(models, memory_required=0): inference_memory = minimum_inference_memory() extra_mem = max(inference_memory, memory_required) + models = set(models) models_to_load = [] models_already_loaded = [] for x in models: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index bc51743af..1572bfab3 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -150,6 +150,15 @@ class ModelPatcher: def add_object_patch(self, name, obj): self.object_patches[name] = obj + def get_model_object(self, name): + if name in self.object_patches: + return self.object_patches[name] + else: + if name in self.object_patches_backup: + return self.object_patches_backup[name] + else: + return utils.get_attr(self.model, name) + def model_patches_to(self, device): to = self.model_options["transformer_options"] if "patches" in to: @@ -278,7 +287,7 @@ class ModelPatcher: if weight_key in self.patches: m.weight_function = LowVramPatch(weight_key, self) if bias_key in self.patches: - m.bias_function = LowVramPatch(weight_key, self) + m.bias_function = LowVramPatch(bias_key, self) m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True @@ -462,4 +471,4 @@ class ModelPatcher: for k in keys: utils.set_attr(self.model, k, self.object_patches_backup[k]) - self.object_patches_backup = {} + self.object_patches_backup.clear() diff --git a/comfy/node_helpers.py b/comfy/node_helpers.py new file mode 100644 index 000000000..8828a4ec9 --- /dev/null +++ b/comfy/node_helpers.py @@ -0,0 +1,10 @@ + +def conditioning_set_values(conditioning, values={}): + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + for k in values: + n[1][k] = values[k] + c.append(n) + + return c diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index c8b72a829..e69502708 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -30,7 +30,7 @@ from ..model_downloader import get_filename_list_with_downloadable, get_or_downl from ..nodes.common import MAX_RESOLUTION from .. import controlnet from ..open_exr import load_exr - +from .. import node_helpers class CLIPTextEncode: @classmethod @@ -140,13 +140,9 @@ class ConditioningSetArea: CATEGORY = "conditioning" def append(self, conditioning, width, height, x, y, strength): - c = [] - for t in conditioning: - n = [t[0], t[1].copy()] - n[1]['area'] = (height // 8, width // 8, y // 8, x // 8) - n[1]['strength'] = strength - n[1]['set_area_to_bounds'] = False - c.append(n) + c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8), + "strength": strength, + "set_area_to_bounds": False}) return (c, ) class ConditioningSetAreaPercentage: @@ -165,13 +161,9 @@ class ConditioningSetAreaPercentage: CATEGORY = "conditioning" def append(self, conditioning, width, height, x, y, strength): - c = [] - for t in conditioning: - n = [t[0], t[1].copy()] - n[1]['area'] = ("percentage", height, width, y, x) - n[1]['strength'] = strength - n[1]['set_area_to_bounds'] = False - c.append(n) + c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x), + "strength": strength, + "set_area_to_bounds": False}) return (c, ) class ConditioningSetAreaStrength: @@ -186,11 +178,7 @@ class ConditioningSetAreaStrength: CATEGORY = "conditioning" def append(self, conditioning, strength): - c = [] - for t in conditioning: - n = [t[0], t[1].copy()] - n[1]['strength'] = strength - c.append(n) + c = node_helpers.conditioning_set_values(conditioning, {"strength": strength}) return (c, ) @@ -208,19 +196,15 @@ class ConditioningSetMask: CATEGORY = "conditioning" def append(self, conditioning, mask, set_cond_area, strength): - c = [] set_area_to_bounds = False if set_cond_area != "default": set_area_to_bounds = True if len(mask.shape) < 3: mask = mask.unsqueeze(0) - for t in conditioning: - n = [t[0], t[1].copy()] - _, h, w = mask.shape - n[1]['mask'] = mask - n[1]['set_area_to_bounds'] = set_area_to_bounds - n[1]['mask_strength'] = strength - c.append(n) + + c = node_helpers.conditioning_set_values(conditioning, {"mask": mask, + "set_area_to_bounds": set_area_to_bounds, + "mask_strength": strength}) return (c, ) class ConditioningZeroOut: @@ -255,13 +239,8 @@ class ConditioningSetTimestepRange: CATEGORY = "advanced/conditioning" def set_range(self, conditioning, start, end): - c = [] - for t in conditioning: - d = t[1].copy() - d['start_percent'] = start - d['end_percent'] = end - n = [t[0], d] - c.append(n) + c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start, + "end_percent": end}) return (c, ) class VAEDecode: @@ -402,13 +381,8 @@ class InpaintModelConditioning: out = [] for conditioning in [positive, negative]: - c = [] - for t in conditioning: - d = t[1].copy() - d["concat_latent_image"] = concat_latent - d["concat_mask"] = mask - n = [t[0], d] - c.append(n) + c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent, + "concat_mask": mask}) out.append(c) return (out[0], out[1], out_latent) diff --git a/comfy/nodes/vanilla_node_importing.py b/comfy/nodes/vanilla_node_importing.py index 478b6ee58..b6e40b308 100644 --- a/comfy/nodes/vanilla_node_importing.py +++ b/comfy/nodes/vanilla_node_importing.py @@ -8,14 +8,14 @@ import sys import time import types from contextlib import contextmanager -from typing import Dict, List +from typing import Dict, List, Iterable from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath from . import base_nodes from .package_typing import ExportedNodes -def _vanilla_load_importing_execute_prestartup_script(node_paths: List[str]) -> None: +def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str]) -> None: def execute_script(script_path): module_name = splitext(script_path)[0] try: @@ -121,7 +121,7 @@ def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes: return exported_nodes -def _vanilla_load_custom_nodes_2(node_paths: List[str]) -> ExportedNodes: +def _vanilla_load_custom_nodes_2(node_paths: Iterable[str]) -> ExportedNodes: base_node_names = set(base_nodes.NODE_CLASS_MAPPINGS.keys()) node_import_times = [] exported_nodes = ExportedNodes() @@ -192,6 +192,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes: if is_git_repository: node_paths += [abspath(join(potential_git_dir_parent, "custom_nodes"))] + node_paths = frozenset(abspath(custom_node_path) for custom_node_path in node_paths) + _vanilla_load_importing_execute_prestartup_script(node_paths) vanilla_custom_nodes = _vanilla_load_custom_nodes_2(node_paths) return vanilla_custom_nodes diff --git a/comfy/sample.py b/comfy/sample.py index 158607075..7a97e54be 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -5,6 +5,7 @@ from . import utils from . import conds import math import numpy as np +import logging def prepare_noise(latent_image, seed, noise_inds=None): """ @@ -25,94 +26,21 @@ def prepare_noise(latent_image, seed, noise_inds=None): noises = torch.cat(noises, axis=0) return noises -def prepare_mask(noise_mask, shape, device): - """ensures noise mask is of proper dimensions""" - noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") - noise_mask = torch.cat([noise_mask] * shape[1], dim=1) - noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0]) - noise_mask = noise_mask.to(device) - return noise_mask - -def get_models_from_cond(cond, model_type): - models = [] - for c in cond: - if model_type in c: - models += [c[model_type]] - return models - -def convert_cond(cond): - out = [] - for c in cond: - temp = c[1].copy() - model_conds = temp.get("model_conds", {}) - if c[0] is not None: - model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove - temp["cross_attn"] = c[0] - temp["model_conds"] = model_conds - out.append(temp) - return out - -def get_additional_models(positive, negative, dtype): - """loads additional models in positive and negative conditioning""" - control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) - - inference_memory = 0 - control_models = [] - for m in control_nets: - control_models += m.get_models() - inference_memory += m.inference_memory_requirements(dtype) - - gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") - gligen = [x[1] for x in gligen] - models = control_models + gligen - return models, inference_memory +def prepare_sampling(model, noise_shape, positive, negative, noise_mask): + logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed") + return model, positive, negative, noise_mask, [] def cleanup_additional_models(models): - """cleanup additional models that were loaded""" - for m in models: - if hasattr(m, 'cleanup'): - m.cleanup() - -def prepare_sampling(model, noise_shape, positive, negative, noise_mask): - device = model.load_device - positive = convert_cond(positive) - negative = convert_cond(negative) - - if noise_mask is not None: - noise_mask = prepare_mask(noise_mask, noise_shape, device) - - real_model = None - models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) - model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) - real_model = model.model - - return real_model, positive, negative, noise_mask, models - + logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed") def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): - real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask) + sampler = samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - noise = noise.to(model.load_device) - latent_image = latent_image.to(model.load_device) - - sampler = samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) + samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.to(model_management.intermediate_device()) - - cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): - real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask) - noise = noise.to(model.load_device) - latent_image = latent_image.to(model.load_device) - sigmas = sigmas.to(model.load_device) - - samples = samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) + samples = samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.to(model_management.intermediate_device()) - cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples - diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py new file mode 100644 index 000000000..2fedef2f1 --- /dev/null +++ b/comfy/sampler_helpers.py @@ -0,0 +1,77 @@ +import torch +from comfy import model_management +from comfy import utils +from comfy import conds + +def prepare_mask(noise_mask, shape, device): + """ensures noise mask is of proper dimensions""" + noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") + noise_mask = torch.cat([noise_mask] * shape[1], dim=1) + noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0]) + noise_mask = noise_mask.to(device) + return noise_mask + +def get_models_from_cond(cond, model_type): + models = [] + for c in cond: + if model_type in c: + models += [c[model_type]] + return models + +def convert_cond(cond): + out = [] + for c in cond: + temp = c[1].copy() + model_conds = temp.get("model_conds", {}) + if c[0] is not None: + model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove + temp["cross_attn"] = c[0] + temp["model_conds"] = model_conds + out.append(temp) + return out + +def get_additional_models(conds, dtype): + """loads additional models in conditioning""" + cnets = [] + gligen = [] + + for k in conds: + cnets += get_models_from_cond(conds[k], "control") + gligen += get_models_from_cond(conds[k], "gligen") + + control_nets = set(cnets) + + inference_memory = 0 + control_models = [] + for m in control_nets: + control_models += m.get_models() + inference_memory += m.inference_memory_requirements(dtype) + + gligen = [x[1] for x in gligen] + models = control_models + gligen + return models, inference_memory + +def cleanup_additional_models(models): + """cleanup additional models that were loaded""" + for m in models: + if hasattr(m, 'cleanup'): + m.cleanup() + + +def prepare_sampling(model, noise_shape, conds): + device = model.load_device + real_model = None + models, inference_memory = get_additional_models(conds, model.model_dtype()) + model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) + real_model = model.model + + return real_model, conds, models + +def cleanup_models(conds, models): + cleanup_additional_models(models) + + control_cleanup = [] + for k in conds: + control_cleanup += get_models_from_cond(conds[k], "control") + + cleanup_additional_models(set(control_cleanup)) diff --git a/comfy/samplers.py b/comfy/samplers.py index e5a42a8b3..2a81ef43e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -5,6 +5,7 @@ import collections from . import model_management import math import logging +from . import sampler_helpers from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES @@ -130,30 +131,23 @@ def cond_cat(c_list): return out -def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): - out_cond = torch.zeros_like(x_in) - out_count = torch.ones_like(x_in) * 1e-37 - - out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.ones_like(x_in) * 1e-37 - - COND = 0 - UNCOND = 1 - +def calc_cond_batch(model, conds, x_in, timestep, model_options): + out_conds = [] + out_counts = [] to_run = [] - for x in cond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - to_run += [(p, COND)] - if uncond is not None: - for x in uncond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue + for i in range(len(conds)): + out_conds.append(torch.zeros_like(x_in)) + out_counts.append(torch.ones_like(x_in) * 1e-37) - to_run += [(p, UNCOND)] + cond = conds[i] + if cond is not None: + for x in cond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, i)] while len(to_run) > 0: first = to_run[0] @@ -225,74 +219,66 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) else: output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - del input_x for o in range(batch_chunks): - if cond_or_uncond[o] == COND: - out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - else: - out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - del mult + cond_index = cond_or_uncond[o] + out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - out_cond /= out_count - del out_count - out_uncond /= out_uncond_count - del out_uncond_count - return out_cond, out_uncond + for i in range(len(out_conds)): + out_conds[i] /= out_counts[i] + + return out_conds + +def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove + logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") + return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) + +def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None): + if "sampler_cfg_function" in model_options: + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} + cfg_result = x - model_options["sampler_cfg_function"](args) + else: + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale + + for fn in model_options.get("sampler_post_cfg_function", []): + args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, + "sigma": timestep, "model_options": model_options, "input": x} + cfg_result = fn(args) + + return cfg_result #The main sampling function shared by all the samplers #Returns denoised def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: - uncond_ = None - else: - uncond_ = uncond + if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: + uncond_ = None + else: + uncond_ = uncond - cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) - if "sampler_cfg_function" in model_options: - args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, - "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} - cfg_result = x - model_options["sampler_cfg_function"](args) - else: - cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale + conds = [cond, uncond_] + out = calc_cond_batch(model, conds, x, timestep, model_options) + return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_) - for fn in model_options.get("sampler_post_cfg_function", []): - args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, - "sigma": timestep, "model_options": model_options, "input": x} - cfg_result = fn(args) - return cfg_result - -class CFGNoisePredictor(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None): - out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) - return out - def forward(self, *args, **kwargs): - return self.apply_model(*args, **kwargs) - -class KSamplerX0Inpaint(torch.nn.Module): +class KSamplerX0Inpaint: def __init__(self, model, sigmas): - super().__init__() self.inner_model = model self.sigmas = sigmas - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): + def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None): if denoise_mask is not None: if "denoise_mask_function" in model_options: denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. - denoise_mask x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask - out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) + out = self.inner_model(x, sigma, model_options=model_options, seed=seed) if denoise_mask is not None: out = out * denoise_mask + self.latent_image * latent_mask return out -def simple_scheduler(model, steps): - s = model.model_sampling +def simple_scheduler(model_sampling, steps): + s = model_sampling sigs = [] ss = len(s.sigmas) / steps for x in range(steps): @@ -300,8 +286,8 @@ def simple_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) -def ddim_scheduler(model, steps): - s = model.model_sampling +def ddim_scheduler(model_sampling, steps): + s = model_sampling sigs = [] ss = max(len(s.sigmas) // steps, 1) x = 1 @@ -312,8 +298,8 @@ def ddim_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) -def normal_scheduler(model, steps, sgm=False, floor=False): - s = model.model_sampling +def normal_scheduler(model_sampling, steps, sgm=False, floor=False): + s = model_sampling start = s.timestep(s.sigma_max) end = s.timestep(s.sigma_min) @@ -574,59 +560,120 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): return KSAMPLER(sampler_function, extra_options, inpaint_options) -def wrap_model(model): - model_denoise = CFGNoisePredictor(model) - return model_denoise -def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - positive = positive[:] - negative = negative[:] +def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None): + for k in conds: + conds[k] = conds[k][:] + resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device) - resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device) - resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device) - - model_wrap = wrap_model(model) - - calculate_start_end_timesteps(model, negative) - calculate_start_end_timesteps(model, positive) - - if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. - latent_image = model.process_latent_in(latent_image) + for k in conds: + calculate_start_end_timesteps(model, conds[k]) if hasattr(model, 'extra_conds'): - positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) - negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + for k in conds: + conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) #make sure each cond area has an opposite one with the same area - for c in positive: - create_cond_with_same_area_if_none(negative, c) - for c in negative: - create_cond_with_same_area_if_none(positive, c) + for k in conds: + for c in conds[k]: + for kk in conds: + if k != kk: + create_cond_with_same_area_if_none(conds[kk], c) - pre_run_control(model, negative + positive) + for k in conds: + pre_run_control(model, conds[k]) - apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) - apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) + if "positive" in conds: + positive = conds["positive"] + for k in conds: + if k != "positive": + apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x]) - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} + return conds - samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) - return model.process_latent_out(samples.to(torch.float32)) +class CFGGuider: + def __init__(self, model_patcher): + self.model_patcher = model_patcher + self.model_options = model_patcher.model_options + self.original_conds = {} + self.cfg = 1.0 + + def set_conds(self, positive, negative): + self.inner_set_conds({"positive": positive, "negative": negative}) + + def set_cfg(self, cfg): + self.cfg = cfg + + def inner_set_conds(self, conds): + for k in conds: + self.original_conds[k] = sampler_helpers.convert_cond(conds[k]) + + def __call__(self, *args, **kwargs): + return self.predict_noise(*args, **kwargs) + + def predict_noise(self, x, timestep, model_options={}, seed=None): + return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) + + def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): + if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. + latent_image = self.inner_model.process_latent_in(latent_image) + + self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) + + extra_args = {"model_options": self.model_options, "seed":seed} + + samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) + return self.inner_model.process_latent_out(samples.to(torch.float32)) + + def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + if sigmas.shape[-1] == 0: + return latent_image + + self.conds = {} + for k in self.original_conds: + self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) + + self.inner_model, self.conds, self.loaded_models = sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) + device = self.model_patcher.load_device + + if denoise_mask is not None: + denoise_mask = sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) + + noise = noise.to(device) + latent_image = latent_image.to(device) + sigmas = sigmas.to(device) + + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + + sampler_helpers.cleanup_models(self.conds, self.loaded_models) + del self.inner_model + del self.conds + del self.loaded_models + return output -def calculate_sigmas_scheduler(model, scheduler_name, steps): +def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + cfg_guider = CFGGuider(model) + cfg_guider.set_conds(positive, negative) + cfg_guider.set_cfg(cfg) + return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + + + +def calculate_sigmas(model_sampling, scheduler_name, steps): if scheduler_name == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max)) elif scheduler_name == "exponential": - sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max)) elif scheduler_name == "normal": - sigmas = normal_scheduler(model, steps) + sigmas = normal_scheduler(model_sampling, steps) elif scheduler_name == "simple": - sigmas = simple_scheduler(model, steps) + sigmas = simple_scheduler(model_sampling, steps) elif scheduler_name == "ddim_uniform": - sigmas = ddim_scheduler(model, steps) + sigmas = ddim_scheduler(model_sampling, steps) elif scheduler_name == "sgm_uniform": - sigmas = normal_scheduler(model, steps, sgm=True) + sigmas = normal_scheduler(model_sampling, steps, sgm=True) else: logging.error("error invalid scheduler {}".format(scheduler_name)) return sigmas @@ -668,7 +715,7 @@ class KSampler: steps += 1 discard_penultimate_sigma = True - sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps) + sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps) if discard_penultimate_sigma: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) @@ -679,9 +726,12 @@ class KSampler: if denoise is None or denoise > 0.9999: self.sigmas = self.calculate_sigmas(steps).to(self.device) else: - new_steps = int(steps/denoise) - sigmas = self.calculate_sigmas(new_steps).to(self.device) - self.sigmas = sigmas[-(steps + 1):] + if denoise <= 0.0: + self.sigmas = torch.FloatTensor([]) + else: + new_steps = int(steps/denoise) + sigmas = self.calculate_sigmas(new_steps).to(self.device) + self.sigmas = sigmas[-(steps + 1):] def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): if sigmas is None: diff --git a/comfy/sd.py b/comfy/sd.py index 9889f4fe8..63b4d5126 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -600,7 +600,7 @@ def load_unet(unet_path): raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) return model -def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None): +def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}): clip_sd = None load_models = [model] if clip is not None: @@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m model_management.load_models_gpu(load_models) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) + for k in extra_keys: + sd[k] = extra_keys[k] + utils.save_torch_file(sd, output_path, metadata=metadata) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 5b2eb73fd..b3b69e05b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -174,6 +174,11 @@ class SDXL(supported_models_base.BASE): self.sampling_settings["sigma_max"] = 80.0 self.sampling_settings["sigma_min"] = 0.002 return model_base.ModelType.EDM + elif "edm_vpred.sigma_max" in state_dict: + self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item()) + if "edm_vpred.sigma_min" in state_dict: + self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item()) + return model_base.ModelType.V_PREDICTION_EDM elif "v_pred" in state_dict: return model_base.ModelType.V_PREDICTION else: @@ -334,6 +339,11 @@ class Stable_Zero123(supported_models_base.BASE): "num_head_channels": -1, } + required_keys = { + "cc_projection.weight": None, + "cc_projection.bias": None, + } + clip_vision_prefix = "cond_stage_model.model.visual." latent_format = latent_formats.SD15 @@ -439,6 +449,33 @@ class Stable_Cascade_B(Stable_Cascade_C): out = model_base.StableCascade_B(self, device=device) return out +class SD15_instructpix2pix(SD15): + unet_config = { + "context_dim": 768, + "model_channels": 320, + "use_linear_in_transformer": False, + "adm_in_channels": None, + "use_temporal_attention": False, + "in_channels": 8, + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.SD15_instructpix2pix(self, device=device) + +class SDXL_instructpix2pix(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 0, 2, 2, 10, 10], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + "in_channels": 8, + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device) + +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p] -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p] models += [SVD_img2vid] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 4d7e25936..6196daabf 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -16,6 +16,8 @@ class BASE: "num_head_channels": 64, } + required_keys = {} + clip_prefix = [] clip_vision_prefix = None noise_aug_config = None @@ -28,10 +30,14 @@ class BASE: manual_cast_dtype = None @classmethod - def matches(s, unet_config): + def matches(s, unet_config, state_dict=None): for k in s.unet_config: if k not in unet_config or s.unet_config[k] != unet_config[k]: return False + if state_dict is not None: + for k in s.required_keys: + if k not in state_dict: + return False return True def model_type(self, state_dict, prefix=""): diff --git a/comfy/web/lib/litegraph.core.js b/comfy/web/lib/litegraph.core.js index 4ff05ae81..427a62b59 100644 --- a/comfy/web/lib/litegraph.core.js +++ b/comfy/web/lib/litegraph.core.js @@ -7247,7 +7247,7 @@ LGraphNode.prototype.executeAction = function(action) //create links for (var i = 0; i < clipboard_info.links.length; ++i) { var link_info = clipboard_info.links[i]; - var origin_node; + var origin_node = undefined; var origin_node_relative_id = link_info[0]; if (origin_node_relative_id != null) { origin_node = nodes[origin_node_relative_id]; diff --git a/comfy/web/scripts/pnginfo.js b/comfy/web/scripts/pnginfo.js index 169609209..7132fb60f 100644 --- a/comfy/web/scripts/pnginfo.js +++ b/comfy/web/scripts/pnginfo.js @@ -170,9 +170,12 @@ export async function importA1111(graph, parameters) { const opts = parameters .substr(p) .split("\n")[1] - .split(",") + .match(new RegExp("\\s*([^:]+:\\s*([^\"\\{].*?|\".*?\"|\\{.*?\\}))\\s*(,|$)", "g")) .reduce((p, n) => { const s = n.split(":"); + if (s[1].endsWith(',')) { + s[1] = s[1].substr(0, s[1].length -1); + } p[s[0].trim().toLowerCase()] = s[1].trim(); return p; }, {}); @@ -191,6 +194,7 @@ export async function importA1111(graph, parameters) { const vaeLoaderNode = LiteGraph.createNode("VAELoader"); const saveNode = LiteGraph.createNode("SaveImage"); let hrSamplerNode = null; + let hrSteps = null; const ceil64 = (v) => Math.ceil(v / 64) * 64; @@ -290,6 +294,9 @@ export async function importA1111(graph, parameters) { model(v) { setWidgetValue(ckptNode, "ckpt_name", v, true); }, + "vae"(v) { + setWidgetValue(vaeLoaderNode, "vae_name", v, true); + }, "cfg scale"(v) { setWidgetValue(samplerNode, "cfg", +v); }, @@ -316,6 +323,7 @@ export async function importA1111(graph, parameters) { const h = ceil64(+wxh[1]); const hrUp = popOpt("hires upscale"); const hrSz = popOpt("hires resize"); + hrSteps = popOpt("hires steps"); let hrMethod = popOpt("hires upscaler"); setWidgetValue(imageNode, "width", w); @@ -398,7 +406,7 @@ export async function importA1111(graph, parameters) { } if (hrSamplerNode) { - setWidgetValue(hrSamplerNode, "steps", getWidget(samplerNode, "steps").value); + setWidgetValue(hrSamplerNode, "steps", hrSteps? +hrSteps : getWidget(samplerNode, "steps").value); setWidgetValue(hrSamplerNode, "cfg", getWidget(samplerNode, "cfg").value); setWidgetValue(hrSamplerNode, "scheduler", getWidget(samplerNode, "scheduler").value); setWidgetValue(hrSamplerNode, "sampler_name", getWidget(samplerNode, "sampler_name").value); @@ -415,7 +423,7 @@ export async function importA1111(graph, parameters) { graph.arrange(); - for (const opt of ["model hash", "ensd"]) { + for (const opt of ["model hash", "ensd", "version", "vae hash", "ti hashes", "lora hashes", "hashes"]) { delete opts[opt]; } diff --git a/comfy_extras/nodes/nodes_custom_sampler.py b/comfy_extras/nodes/nodes_custom_sampler.py index 923e40b71..4a11f8b14 100644 --- a/comfy_extras/nodes/nodes_custom_sampler.py +++ b/comfy_extras/nodes/nodes_custom_sampler.py @@ -6,6 +6,7 @@ from comfy.k_diffusion import sampling as k_diffusion_sampling from comfy.cmd import latent_preview import torch from comfy import utils +from comfy import node_helpers class BasicScheduler: @@ -26,10 +27,11 @@ class BasicScheduler: def get_sigmas(self, model, scheduler, steps, denoise): total_steps = steps if denoise < 1.0: + if denoise <= 0.0: + return (torch.FloatTensor([]),) total_steps = int(steps/denoise) - model_management.load_models_gpu([model]) - sigmas = samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu() + sigmas = samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() sigmas = sigmas[-(steps + 1):] return (sigmas, ) @@ -162,6 +164,9 @@ class FlipSigmas: FUNCTION = "get_sigmas" def get_sigmas(self, sigmas): + if len(sigmas) == 0: + return (sigmas,) + sigmas = sigmas.flip(0) if sigmas[0] == 0: sigmas[0] = 0.0001 @@ -334,6 +339,24 @@ class SamplerDPMAdaptative: "s_noise":s_noise }) return (sampler, ) +class Noise_EmptyNoise: + def __init__(self): + self.seed = 0 + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + + +class Noise_RandomNoise: + def __init__(self, seed): + self.seed = seed + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None + return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds) + class SamplerCustom: @classmethod def INPUT_TYPES(s): @@ -361,10 +384,9 @@ class SamplerCustom: latent = latent_image latent_image = latent["samples"] if not add_noise: - noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + noise = Noise_EmptyNoise().generate_noise(latent) else: - batch_inds = latent["batch_index"] if "batch_index" in latent else None - noise = sample.prepare_noise(latent_image, noise_seed, batch_inds) + noise = Noise_RandomNoise(noise_seed).generate_noise(latent) noise_mask = None if "noise_mask" in latent: @@ -385,6 +407,161 @@ class SamplerCustom: out_denoised = out return (out, out_denoised) +class Guider_Basic(comfy.samplers.CFGGuider): + def set_conds(self, positive): + self.inner_set_conds({"positive": positive}) + +class BasicGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "conditioning": ("CONDITIONING", ), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, conditioning): + guider = Guider_Basic(model) + guider.set_conds(conditioning) + return (guider,) + +class CFGGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, positive, negative, cfg): + guider = comfy.samplers.CFGGuider(model) + guider.set_conds(positive, negative) + guider.set_cfg(cfg) + return (guider,) + +class Guider_DualCFG(comfy.samplers.CFGGuider): + def set_cfg(self, cfg1, cfg2): + self.cfg1 = cfg1 + self.cfg2 = cfg2 + + def set_conds(self, positive, middle, negative): + middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"}) + self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative}) + + def predict_noise(self, x, timestep, model_options={}, seed=None): + negative_cond = self.conds.get("negative", None) + middle_cond = self.conds.get("middle", None) + + out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options) + return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 + +class DualCFGGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "cond1": ("CONDITIONING", ), + "cond2": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative): + guider = Guider_DualCFG(model) + guider.set_conds(cond1, cond2, negative) + guider.set_cfg(cfg_conds, cfg_cond2_negative) + return (guider,) + +class DisableNoise: + @classmethod + def INPUT_TYPES(s): + return {"required":{ + } + } + + RETURN_TYPES = ("NOISE",) + FUNCTION = "get_noise" + CATEGORY = "sampling/custom_sampling/noise" + + def get_noise(self): + return (Noise_EmptyNoise(),) + + +class RandomNoise(DisableNoise): + @classmethod + def INPUT_TYPES(s): + return {"required":{ + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + } + } + + def get_noise(self, noise_seed): + return (Noise_RandomNoise(noise_seed),) + + +class SamplerCustomAdvanced: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"noise": ("NOISE", ), + "guider": ("GUIDER", ), + "sampler": ("SAMPLER", ), + "sigmas": ("SIGMAS", ), + "latent_image": ("LATENT", ), + } + } + + RETURN_TYPES = ("LATENT","LATENT") + RETURN_NAMES = ("output", "denoised_output") + + FUNCTION = "sample" + + CATEGORY = "sampling/custom_sampling" + + def sample(self, noise, guider, sampler, sigmas, latent_image): + latent = latent_image + latent_image = latent["samples"] + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + x0_output = {} + callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output) + + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED + samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed) + samples = samples.to(comfy.model_management.intermediate_device()) + + out = latent.copy() + out["samples"] = samples + if "x0" in x0_output: + out_denoised = latent.copy() + out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + else: + out_denoised = out + return (out, out_denoised) + NODE_CLASS_MAPPINGS = { "SamplerCustom": SamplerCustom, "BasicScheduler": BasicScheduler, @@ -402,4 +579,11 @@ NODE_CLASS_MAPPINGS = { "SamplerDPMAdaptative": SamplerDPMAdaptative, "SplitSigmas": SplitSigmas, "FlipSigmas": FlipSigmas, + + "CFGGuider": CFGGuider, + "DualCFGGuider": DualCFGGuider, + "BasicGuider": BasicGuider, + "RandomNoise": RandomNoise, + "DisableNoise": DisableNoise, + "SamplerCustomAdvanced": SamplerCustomAdvanced, } diff --git a/comfy_extras/nodes/nodes_ip2p.py b/comfy_extras/nodes/nodes_ip2p.py new file mode 100644 index 000000000..c2e70a84c --- /dev/null +++ b/comfy_extras/nodes/nodes_ip2p.py @@ -0,0 +1,45 @@ +import torch + +class InstructPixToPixConditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "pixels": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/instructpix2pix" + + def encode(self, positive, negative, pixels, vae): + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 + + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] + + concat_latent = vae.encode(pixels) + + out_latent = {} + out_latent["samples"] = torch.zeros_like(concat_latent) + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + d["concat_latent_image"] = concat_latent + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1], out_latent) + +NODE_CLASS_MAPPINGS = { + "InstructPixToPixConditioning": InstructPixToPixConditioning, +} diff --git a/comfy_extras/nodes/nodes_model_merging.py b/comfy_extras/nodes/nodes_model_merging.py index 2a1363047..7a5532993 100644 --- a/comfy_extras/nodes/nodes_model_merging.py +++ b/comfy_extras/nodes/nodes_model_merging.py @@ -1,8 +1,10 @@ from comfy import sd, utils from comfy import model_base from comfy import model_management - +from comfy import model_sampling from comfy.cmd import folder_paths + +import torch import json import os @@ -188,6 +190,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", # "v2-inpainting" + extra_keys = {} + _model_sampling = model.get_model_object("model_sampling") + if isinstance(_model_sampling, model_sampling.ModelSamplingContinuousEDM): + if isinstance(_model_sampling, model_sampling.V_PREDICTION): + extra_keys["edm_vpred.sigma_max"] = torch.tensor(_model_sampling.sigma_max).float() + extra_keys["edm_vpred.sigma_min"] = torch.tensor(_model_sampling.sigma_min).float() + if model.model.model_type == model_base.ModelType.EPS: metadata["modelspec.predict_key"] = "epsilon" elif model.model.model_type == model_base.ModelType.V_PREDICTION: @@ -202,7 +211,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) - sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata) + sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys) class CheckpointSave: def __init__(self): diff --git a/comfy_extras/nodes/nodes_perpneg.py b/comfy_extras/nodes/nodes_perpneg.py index 43405a2ab..3e2f75146 100644 --- a/comfy_extras/nodes/nodes_perpneg.py +++ b/comfy_extras/nodes/nodes_perpneg.py @@ -1,8 +1,10 @@ import torch from comfy import sample from comfy import samplers +from comfy import sampler_helpers +#TODO: This node should be removed and replaced with one that uses the new Guider/SamplerCustomAdvanced. class PerpNeg: @classmethod def INPUT_TYPES(s): @@ -17,7 +19,7 @@ class PerpNeg: def patch(self, model, empty_conditioning, neg_scale): m = model.clone() - nocond = sample.convert_cond(empty_conditioning) + nocond = sampler_helpers.convert_cond(empty_conditioning) def cfg_function(args): model = args["model"] @@ -29,7 +31,7 @@ class PerpNeg: model_options = args["model_options"] nocond_processed = samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative") - (noise_pred_nocond, _) = samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options) + (noise_pred_nocond,) = samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options) pos = noise_pred_pos - noise_pred_nocond neg = noise_pred_neg - noise_pred_nocond diff --git a/comfy_extras/nodes/nodes_sag.py b/comfy_extras/nodes/nodes_sag.py index e520e13e4..1d6dd40cd 100644 --- a/comfy_extras/nodes/nodes_sag.py +++ b/comfy_extras/nodes/nodes_sag.py @@ -150,7 +150,7 @@ class SelfAttentionGuidance: degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) degraded_noised = degraded + x - uncond_pred # call into the UNet - (sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options) + (sag,) = samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options) return cfg_result + (degraded - sag) * sag_scale m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)