diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 937c5a388..dd6f7bbe5 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler): mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) return mu - def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): extra_args = {} if extra_args is None else extra_args noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler @@ -737,3 +736,17 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) +@torch.no_grad() +def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + x = denoised + if sigmas[i + 1] > 0: + x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) + return x diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 55800e86e..6d7a61c41 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -6,7 +6,7 @@ import comfy.utils import comfy.model_management class ModelPatcher: - def __init__(self, model, load_device, offload_device, size=0, current_device=None): + def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): self.size = size self.model = model self.patches = {} @@ -22,6 +22,8 @@ class ModelPatcher: else: self.current_device = current_device + self.weight_inplace_update = weight_inplace_update + def model_size(self): if self.size > 0: return self.size @@ -134,6 +136,7 @@ class ModelPatcher: return list(p) def get_key_patches(self, filter_prefix=None): + comfy.model_management.unload_model_clones(self) model_sd = self.model_state_dict() p = {} for k in model_sd: @@ -170,15 +173,20 @@ class ModelPatcher: weight = model_sd[key] + inplace_update = self.weight_inplace_update + if key not in self.backup: - self.backup[key] = weight.to(self.offload_device) + self.backup[key] = weight.to(device=device_to, copy=inplace_update) if device_to is not None: temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) else: temp_weight = weight.to(torch.float32, copy=True) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - comfy.utils.set_attr(self.model, key, out_weight) + if inplace_update: + comfy.utils.copy_to_param(self.model, key, out_weight) + else: + comfy.utils.set_attr(self.model, key, out_weight) del temp_weight if device_to is not None: @@ -294,8 +302,12 @@ class ModelPatcher: def unpatch_model(self, device_to=None): keys = list(self.backup.keys()) - for k in keys: - comfy.utils.set_attr(self.model, k, self.backup[k]) + if self.weight_inplace_update: + for k in keys: + comfy.utils.copy_to_param(self.model, k, self.backup[k]) + else: + for k in keys: + comfy.utils.set_attr(self.model, k, self.backup[k]) self.backup = {} diff --git a/comfy/ops.py b/comfy/ops.py index 610d54584..0bfb698aa 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1,29 +1,23 @@ import torch from contextlib import contextmanager -class Linear(torch.nn.Module): - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) - if bias: - self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs)) - else: - self.register_parameter('bias', None) - - def forward(self, input): - return torch.nn.functional.linear(input, self.weight, self.bias) +class Linear(torch.nn.Linear): + def reset_parameters(self): + return None class Conv2d(torch.nn.Conv2d): def reset_parameters(self): return None +class Conv3d(torch.nn.Conv3d): + def reset_parameters(self): + return None + def conv_nd(dims, *args, **kwargs): if dims == 2: return Conv2d(*args, **kwargs) + elif dims == 3: + return Conv3d(*args, **kwargs) else: raise ValueError(f"unsupported dimensions: {dims}") diff --git a/comfy/samplers.py b/comfy/samplers.py index 964febb26..a839ee9e2 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -248,7 +248,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options) if "sampler_cfg_function" in model_options: - args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x} + args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} return x - model_options["sampler_cfg_function"](args) else: return uncond + (cond - uncond) * cond_scale @@ -519,7 +519,7 @@ class UNIPCBH2(Sampler): KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"] + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] def ksampler(sampler_name, extra_options={}, inpaint_options={}): class KSAMPLER(Sampler): diff --git a/comfy/utils.py b/comfy/utils.py index 6a0c54e80..4b484d07a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -261,6 +261,14 @@ def set_attr(obj, attr, value): setattr(obj, attrs[-1], torch.nn.Parameter(value)) del prev +def copy_to_param(obj, attr, value): + # inplace update tensor instead of replacing it + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + prev = getattr(obj, attrs[-1]) + prev.data.copy_(value) + def get_attr(obj, attr): attrs = attr.split(".") for name in attrs: diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index b52ad8fbd..154ecd0d2 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -188,7 +188,7 @@ class SamplerCustom: {"model": ("MODEL",), "add_noise": ("BOOLEAN", {"default": True}), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), "positive": ("CONDITIONING", ), "negative": ("CONDITIONING", ), "sampler": ("SAMPLER", ), diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index c02cfb05a..399123eaa 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -1,6 +1,72 @@ import folder_paths import comfy.sd import comfy.model_sampling +import torch + +class LCM(comfy.model_sampling.EPS): + def calculate_denoised(self, sigma, model_output, model_input): + timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + x0 = model_input - model_output * sigma + + sigma_data = 0.5 + scaled_timestep = timestep * 10.0 #timestep_scaling + + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 + + return c_out * x0 + c_skip * model_input + +class ModelSamplingDiscreteLCM(torch.nn.Module): + def __init__(self): + super().__init__() + self.sigma_data = 1.0 + timesteps = 1000 + beta_start = 0.00085 + beta_end = 0.012 + + betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + + original_timesteps = 50 + self.skip_steps = timesteps // original_timesteps + + + alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32) + for x in range(original_timesteps): + alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] + + sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 + self.set_sigmas(sigmas) + + def set_sigmas(self, sigmas): + self.register_buffer('sigmas', sigmas) + self.register_buffer('log_sigmas', sigmas.log()) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + log_sigma = sigma.log() + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1) + + def sigma(self, timestep): + t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) + low_idx = t.floor().long() + high_idx = t.ceil().long() + w = t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp() + + def percent_to_sigma(self, percent): + return self.sigma(torch.tensor(percent * 999.0)) def rescale_zero_terminal_snr_sigmas(sigmas): @@ -26,7 +92,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction"],), + "sampling": (["eps", "v_prediction", "lcm"],), "zsnr": ("BOOLEAN", {"default": False}), }} @@ -38,20 +104,65 @@ class ModelSamplingDiscrete: def patch(self, model, sampling, zsnr): m = model.clone() + sampling_base = comfy.model_sampling.ModelSamplingDiscrete if sampling == "eps": sampling_type = comfy.model_sampling.EPS elif sampling == "v_prediction": sampling_type = comfy.model_sampling.V_PREDICTION + elif sampling == "lcm": + sampling_type = LCM + sampling_base = ModelSamplingDiscreteLCM - class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type): + class ModelSamplingAdvanced(sampling_base, sampling_type): pass model_sampling = ModelSamplingAdvanced() if zsnr: model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) + m.add_object_patch("model_sampling", model_sampling) return (m, ) +class RescaleCFG: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, multiplier): + def rescale_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + sigma = args["sigma"] + sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1)) + x_orig = args["input"] + + #rescale cfg has to be done on v-pred model output + x = x_orig / (sigma * sigma + 1.0) + cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma) + uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma) + + #rescalecfg + x_cfg = uncond + cond_scale * (cond - uncond) + ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True) + ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True) + + x_rescaled = x_cfg * (ro_pos / ro_cfg) + x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg + + return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5) + + m = model.clone() + m.set_model_sampler_cfg_function(rescale_cfg) + return (m, ) + NODE_CLASS_MAPPINGS = { "ModelSamplingDiscrete": ModelSamplingDiscrete, + "RescaleCFG": RescaleCFG, } diff --git a/nodes.py b/nodes.py index 5ed015442..2bbfd8fe8 100644 --- a/nodes.py +++ b/nodes.py @@ -1218,7 +1218,7 @@ class KSampler: {"model": ("MODEL",), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), "positive": ("CONDITIONING", ), @@ -1244,7 +1244,7 @@ class KSamplerAdvanced: "add_noise": (["enable", "disable"], ), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), "positive": ("CONDITIONING", ),