diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 763d8cc78..385a25164 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm from . import utils from . import deis import comfy.model_patcher +import comfy.model_sampling def append_zero(x): return torch.cat([x, x.new_zeros([1])]) @@ -509,6 +510,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac @torch.no_grad() def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST): + return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) + """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler @@ -541,6 +545,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, return x +@torch.no_grad() +def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1 + lambda_fn = lambda sigma: ((1-sigma)/sigma).log() + + # logged_x = x.unsqueeze(0) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta + sigma_down = sigmas[i+1] * downstep_ratio + alpha_ip1 = 1 - sigmas[i+1] + alpha_down = 1 - sigma_down + renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5 + # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Euler method + d = to_d(x, sigmas[i], denoised) + dt = sigma_down - sigmas[i] + x = x + d * dt + else: + # DPM-Solver++(2S) + if sigmas[i] == 1.0: + sigma_s = 0.9999 + else: + t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down) + r = 1 / 2 + h = t_down - t_i + s = t_i + r * h + sigma_s = sigma_fn(s) + # sigma_s = sigmas[i+1] + sigma_s_i_ratio = sigma_s / sigmas[i] + u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised + D_i = model(u, sigma_s * s_in, **extra_args) + sigma_down_i_ratio = sigma_down / sigmas[i] + x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i + # print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff) + # Noise addition + if sigmas[i + 1] > 0 and eta > 0: + x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff + # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0) + return x + @torch.no_grad() def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): """DPM-Solver++ (stochastic).""" diff --git a/comfy/ldm/flux/controlnet_xlabs.py b/comfy/ldm/flux/controlnet_xlabs.py index 3f40021b2..5d700f16c 100644 --- a/comfy/ldm/flux/controlnet_xlabs.py +++ b/comfy/ldm/flux/controlnet_xlabs.py @@ -82,7 +82,7 @@ class ControlNetFlux(Flux): block_res_sample = controlnet_block(block_res_sample) controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) - return {"output": (controlnet_block_res_samples * 10)[:19]} + return {"input": (controlnet_block_res_samples * 10)[:19]} def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs): hint = hint * 2.0 - 1.0 diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index b5373540a..63970cad2 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -114,19 +114,28 @@ class Flux(nn.Module): ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - for i in range(len(self.double_blocks)): - img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe) + for i, block in enumerate(self.double_blocks): + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) - if control is not None: #Controlnet - control_o = control.get("output") - if i < len(control_o): - add = control_o[i] + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] if add is not None: img += add img = torch.cat((txt, img), 1) - for block in self.single_blocks: + + for i, block in enumerate(self.single_blocks): img = block(img, vec=vec, pe=pe) + + if control is not None: # Controlnet + control_o = control.get("output") + if i < len(control_o): + add = control_o[i] + if add is not None: + img[:, txt.shape[1] :, ...] += add + img = img[:, txt.shape[1] :, ...] img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) diff --git a/comfy/sd.py b/comfy/sd.py index cae7812e3..49cef2f2f 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -670,10 +670,13 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if clip is not None: load_models.append(clip.load_model()) clip_sd = clip.get_sd() + vae_sd = None + if vae is not None: + vae_sd = vae.get_sd() model_management.load_models_gpu(load_models, force_patch_weights=True) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None - sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) + sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) for k in extra_keys: sd[k] = extra_keys[k] diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index 136f9a984..8fb8bf799 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -333,6 +333,25 @@ class VAESave: comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata) return {} +class ModelSave: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "advanced/model_merging" + + def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None): + save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) + return {} + NODE_CLASS_MAPPINGS = { "ModelMergeSimple": ModelMergeSimple, "ModelMergeBlocks": ModelMergeBlocks, @@ -344,4 +363,5 @@ NODE_CLASS_MAPPINGS = { "CLIPMergeAdd": CLIPAdd, "CLIPSave": CLIPSave, "VAESave": VAESave, + "ModelSave": ModelSave, } diff --git a/folder_paths.py b/folder_paths.py index 3db1da61a..74a7d527c 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -17,7 +17,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions) folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions) folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions) -folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions) +folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions) folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions) folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) @@ -44,6 +44,10 @@ user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {} +def map_legacy(folder_name: str) -> str: + legacy = {"unet": "diffusion_models"} + return legacy.get(folder_name, folder_name) + if not os.path.exists(input_directory): try: os.makedirs(input_directory) @@ -128,12 +132,14 @@ def exists_annotated_filepath(name) -> bool: def add_model_folder_path(folder_name: str, full_folder_path: str) -> None: global folder_names_and_paths + folder_name = map_legacy(folder_name) if folder_name in folder_names_and_paths: folder_names_and_paths[folder_name][0].append(full_folder_path) else: folder_names_and_paths[folder_name] = ([full_folder_path], set()) def get_folder_paths(folder_name: str) -> list[str]: + folder_name = map_legacy(folder_name) return folder_names_and_paths[folder_name][0][:] def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]: @@ -180,6 +186,7 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str]) def get_full_path(folder_name: str, filename: str) -> str | None: global folder_names_and_paths + folder_name = map_legacy(folder_name) if folder_name not in folder_names_and_paths: return None folders = folder_names_and_paths[folder_name] @@ -194,6 +201,7 @@ def get_full_path(folder_name: str, filename: str) -> str | None: return None def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: + folder_name = map_legacy(folder_name) global folder_names_and_paths output_list = set() folders = folder_names_and_paths[folder_name] @@ -208,6 +216,7 @@ def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], f def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None: global filename_list_cache global folder_names_and_paths + folder_name = map_legacy(folder_name) if folder_name not in filename_list_cache: return None out = filename_list_cache[folder_name] @@ -227,6 +236,7 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float] return out def get_filename_list(folder_name: str) -> list[str]: + folder_name = map_legacy(folder_name) out = cached_filename_list_(folder_name) if out is None: out = get_filename_list_(folder_name) diff --git a/main.py b/main.py index e9d6ed209..9bd985149 100644 --- a/main.py +++ b/main.py @@ -242,6 +242,7 @@ if __name__ == "__main__": folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")) folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) + folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models")) if args.input_directory: input_dir = os.path.abspath(args.input_directory) diff --git a/models/diffusion_models/put_diffusion_model_files_here b/models/diffusion_models/put_diffusion_model_files_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index b817a865e..513cd0c7f 100644 --- a/nodes.py +++ b/nodes.py @@ -855,7 +855,7 @@ class ControlNetApplyAdvanced: class UNETLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ), + return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ), "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],) }} RETURN_TYPES = ("MODEL",) @@ -870,7 +870,7 @@ class UNETLoader: elif weight_dtype == "fp8_e5m2": model_options["dtype"] = torch.float8_e5m2 - unet_path = folder_paths.get_full_path("unet", unet_name) + unet_path = folder_paths.get_full_path("diffusion_models", unet_name) model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) return (model,)