From e7b8e240f76d2afe0eed6db7a35833923792ae94 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 12 Mar 2024 04:34:34 -0400 Subject: [PATCH 01/19] Add SamplerLMS node. --- comfy_extras/nodes_custom_sampler.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 0ad1246a6..f8e5f9752 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -245,6 +245,22 @@ class SamplerEulerAncestral: sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) return (sampler, ) +class SamplerLMS: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"order": ("INT", {"default": 4, "min": 1, "max": 100}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, order): + sampler = comfy.samplers.ksampler("lms", {"order": order}) + return (sampler, ) + class SamplerCustom: @classmethod def INPUT_TYPES(s): @@ -306,6 +322,7 @@ NODE_CLASS_MAPPINGS = { "SDTurboScheduler": SDTurboScheduler, "KSamplerSelect": KSamplerSelect, "SamplerEulerAncestral": SamplerEulerAncestral, + "SamplerLMS": SamplerLMS, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, "SplitSigmas": SplitSigmas, From eda87043862f743b0a0467735f8531f7c4709b3a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 12 Mar 2024 12:16:37 -0400 Subject: [PATCH 02/19] Add SamplerDPMPP_3M_SDE node. --- comfy_extras/nodes_custom_sampler.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index f8e5f9752..8b808ce0b 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -181,6 +181,28 @@ class KSamplerSelect: sampler = comfy.samplers.sampler_object(sampler_name) return (sampler, ) +class SamplerDPMPP_3M_SDE: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "noise_device": (['gpu', 'cpu'], ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_3m_sde" + else: + sampler_name = "dpmpp_3m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) + return (sampler, ) + class SamplerDPMPP_2M_SDE: @classmethod def INPUT_TYPES(s): @@ -323,6 +345,7 @@ NODE_CLASS_MAPPINGS = { "KSamplerSelect": KSamplerSelect, "SamplerEulerAncestral": SamplerEulerAncestral, "SamplerLMS": SamplerLMS, + "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, "SplitSigmas": SplitSigmas, From db8b59ecff7be40377d17ea69487f442b469c536 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 13 Mar 2024 19:04:41 -0400 Subject: [PATCH 03/19] Lower memory usage for loras in lowvram mode at the cost of perf. 
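Instead of eagerly applying every weight patch when the model is loaded,
modules that fall outside the lowvram memory budget now get a
weight_function / bias_function hook and have their patched weights
recomputed on the fly when cast_bias_weight() casts them to the compute
device. In rough outline (a simplified sketch of the mechanism below, not
the exact code):

    class LowVramPatch:
        def __init__(self, key, model_patcher):
            self.key = key
            self.model_patcher = model_patcher

        def __call__(self, weight):
            # recompute the lora-patched weight on every forward pass
            return self.model_patcher.calculate_weight(
                self.model_patcher.patches[self.key], weight, self.key)

    # inside cast_bias_weight(), after moving the raw weight to the device:
    if s.weight_function is not None:
        weight = s.weight_function(weight)

Only the unpatched weights stay resident, trading compute for memory: the
patch math runs again on every forward pass of an offloaded module.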
--- comfy/model_management.py | 36 +++------------- comfy/model_patcher.py | 91 +++++++++++++++++++++++++++++++-------- comfy/ops.py | 22 ++++++++++ 3 files changed, 101 insertions(+), 48 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2f0a0a627..66fa918bb 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -272,7 +272,6 @@ def module_size(module): class LoadedModel: def __init__(self, model): self.model = model - self.model_accelerated = False self.device = model.load_device def model_memory(self): @@ -285,52 +284,27 @@ class LoadedModel: return self.model_memory() def model_load(self, lowvram_model_memory=0): - patch_model_to = None - if lowvram_model_memory == 0: - patch_model_to = self.device + patch_model_to = self.device self.model.model_patches_to(self.device) self.model.model_patches_to(self.model.model_dtype()) try: - self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU + if lowvram_model_memory > 0: + self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory) + else: + self.real_model = self.model.patch_model(device_to=patch_model_to) except Exception as e: self.model.unpatch_model(self.model.offload_device) self.model_unload() raise e - if lowvram_model_memory > 0: - logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024))) - mem_counter = 0 - for m in self.real_model.modules(): - if hasattr(m, "comfy_cast_weights"): - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - module_mem = module_size(m) - if mem_counter + module_mem < lowvram_model_memory: - m.to(self.device) - mem_counter += module_mem - elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode - m.to(self.device) - mem_counter += module_size(m) - logging.warning("lowvram: loaded module regularly {}".format(m)) - - self.model_accelerated = True - if is_intel_xpu() and not args.disable_ipex_optimize: self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) return self.real_model def model_unload(self): - if self.model_accelerated: - for m in self.real_model.modules(): - if hasattr(m, "prev_comfy_cast_weights"): - m.comfy_cast_weights = m.prev_comfy_cast_weights - del m.prev_comfy_cast_weights - - self.model_accelerated = False - self.model.unpatch_model(self.model.offload_device) self.model.model_patches_to(self.model.offload_device) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5e578dffc..475fa812c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -24,6 +24,7 @@ class ModelPatcher: self.current_device = current_device self.weight_inplace_update = weight_inplace_update + self.model_lowvram = False def model_size(self): if self.size > 0: @@ -178,6 +179,27 @@ class ModelPatcher: sd.pop(k) return sd + def patch_weight_to_device(self, key, device_to=None): + if key not in self.patches: + return + + weight = comfy.utils.get_attr(self.model, key) + + inplace_update = self.weight_inplace_update + + if key not in self.backup: + self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + + if device_to is not None: + temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + else: + temp_weight = weight.to(torch.float32, copy=True) + out_weight = self.calculate_weight(self.patches[key], 
temp_weight, key).to(weight.dtype) + if inplace_update: + comfy.utils.copy_to_param(self.model, key, out_weight) + else: + comfy.utils.set_attr_param(self.model, key, out_weight) + def patch_model(self, device_to=None, patch_weights=True): for k in self.object_patches: old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) @@ -191,23 +213,7 @@ class ModelPatcher: logging.warning("could not patch. key doesn't exist in model: {}".format(key)) continue - weight = model_sd[key] - - inplace_update = self.weight_inplace_update - - if key not in self.backup: - self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) - - if device_to is not None: - temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) - else: - temp_weight = weight.to(torch.float32, copy=True) - out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - if inplace_update: - comfy.utils.copy_to_param(self.model, key, out_weight) - else: - comfy.utils.set_attr_param(self.model, key, out_weight) - del temp_weight + self.patch_weight_to_device(key, device_to) if device_to is not None: self.model.to(device_to) @@ -215,6 +221,47 @@ class ModelPatcher: return self.model + def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0): + self.patch_model(device_to, patch_weights=False) + + logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024))) + class LowVramPatch: + def __init__(self, key, model_patcher): + self.key = key + self.model_patcher = model_patcher + def __call__(self, weight): + return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) + + mem_counter = 0 + for n, m in self.model.named_modules(): + lowvram_weight = False + if hasattr(m, "comfy_cast_weights"): + module_mem = comfy.model_management.module_size(m) + if mem_counter + module_mem >= lowvram_model_memory: + lowvram_weight = True + + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + + if lowvram_weight: + if weight_key in self.patches: + m.weight_function = LowVramPatch(weight_key, self) + if bias_key in self.patches: + m.bias_function = LowVramPatch(weight_key, self) + + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + else: + if hasattr(m, "weight"): + self.patch_weight_to_device(weight_key, device_to) + self.patch_weight_to_device(bias_key, device_to) + m.to(device_to) + mem_counter += comfy.model_management.module_size(m) + logging.debug("lowvram: loaded module regularly {}".format(m)) + + self.model_lowvram = True + return self.model + def calculate_weight(self, patches, weight, key): for p in patches: alpha = p[0] @@ -341,6 +388,16 @@ class ModelPatcher: return weight def unpatch_model(self, device_to=None): + if self.model_lowvram: + for m in self.model.modules(): + if hasattr(m, "prev_comfy_cast_weights"): + m.comfy_cast_weights = m.prev_comfy_cast_weights + del m.prev_comfy_cast_weights + m.weight_function = None + m.bias_function = None + + self.model_lowvram = False + keys = list(self.backup.keys()) if self.weight_inplace_update: diff --git a/comfy/ops.py b/comfy/ops.py index 517688e8b..cfdec355c 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -24,13 +24,20 @@ def cast_bias_weight(s, input): non_blocking = comfy.model_management.device_supports_non_blocking(input.device) if s.bias is not None: bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + if s.bias_function is not None: + bias = 
s.bias_function(bias) weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + if s.weight_function is not None: + weight = s.weight_function(weight) return weight, bias class disable_weight_init: class Linear(torch.nn.Linear): comfy_cast_weights = False + weight_function = None + bias_function = None + def reset_parameters(self): return None @@ -46,6 +53,9 @@ class disable_weight_init: class Conv2d(torch.nn.Conv2d): comfy_cast_weights = False + weight_function = None + bias_function = None + def reset_parameters(self): return None @@ -61,6 +71,9 @@ class disable_weight_init: class Conv3d(torch.nn.Conv3d): comfy_cast_weights = False + weight_function = None + bias_function = None + def reset_parameters(self): return None @@ -76,6 +89,9 @@ class disable_weight_init: class GroupNorm(torch.nn.GroupNorm): comfy_cast_weights = False + weight_function = None + bias_function = None + def reset_parameters(self): return None @@ -92,6 +108,9 @@ class disable_weight_init: class LayerNorm(torch.nn.LayerNorm): comfy_cast_weights = False + weight_function = None + bias_function = None + def reset_parameters(self): return None @@ -111,6 +130,9 @@ class disable_weight_init: class ConvTranspose2d(torch.nn.ConvTranspose2d): comfy_cast_weights = False + weight_function = None + bias_function = None + def reset_parameters(self): return None From 448d9263a258062344e25135fc49d26a7e60887a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 14 Mar 2024 09:30:21 -0400 Subject: [PATCH 04/19] Fix control loras breaking. --- comfy/controlnet.py | 4 ++-- comfy/ops.py | 40 ++++++++++------------------------------ 2 files changed, 12 insertions(+), 32 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 1a72412b1..b6941d8c4 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -201,7 +201,7 @@ class ControlNet(ControlBase): super().cleanup() class ControlLoraOps: - class Linear(torch.nn.Module): + class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} @@ -220,7 +220,7 @@ class ControlLoraOps: else: return torch.nn.functional.linear(input, weight, bias) - class Conv2d(torch.nn.Module): + class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__( self, in_channels, diff --git a/comfy/ops.py b/comfy/ops.py index cfdec355c..eb6507682 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -31,13 +31,13 @@ def cast_bias_weight(s, input): weight = s.weight_function(weight) return weight, bias +class CastWeightBiasOp: + comfy_cast_weights = False + weight_function = None + bias_function = None class disable_weight_init: - class Linear(torch.nn.Linear): - comfy_cast_weights = False - weight_function = None - bias_function = None - + class Linear(torch.nn.Linear, CastWeightBiasOp): def reset_parameters(self): return None @@ -51,11 +51,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class Conv2d(torch.nn.Conv2d): - comfy_cast_weights = False - weight_function = None - bias_function = None - + class Conv2d(torch.nn.Conv2d, CastWeightBiasOp): def reset_parameters(self): return None @@ -69,11 +65,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class Conv3d(torch.nn.Conv3d): - comfy_cast_weights = False - weight_function = None - bias_function = None - + class Conv3d(torch.nn.Conv3d, CastWeightBiasOp): def reset_parameters(self): return 
None @@ -87,11 +79,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class GroupNorm(torch.nn.GroupNorm): - comfy_cast_weights = False - weight_function = None - bias_function = None - + class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp): def reset_parameters(self): return None @@ -106,11 +94,7 @@ class disable_weight_init: return super().forward(*args, **kwargs) - class LayerNorm(torch.nn.LayerNorm): - comfy_cast_weights = False - weight_function = None - bias_function = None - + class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp): def reset_parameters(self): return None @@ -128,11 +112,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class ConvTranspose2d(torch.nn.ConvTranspose2d): - comfy_cast_weights = False - weight_function = None - bias_function = None - + class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): def reset_parameters(self): return None From f2fe635c9f56a8e78866f59b3f110585e75b42f4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Mar 2024 19:34:22 -0400 Subject: [PATCH 05/19] SamplerDPMAdaptative node to test the different options. --- comfy/samplers.py | 4 ++-- comfy_extras/nodes_custom_sampler.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 16b4514e1..d721cb2e5 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -559,11 +559,11 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable) sampler_function = dpm_fast_function elif sampler_name == "dpm_adaptive": - def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable): + def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options): sigma_min = sigmas[-1] if sigma_min == 0: sigma_min = sigmas[-2] - return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable) + return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options) sampler_function = dpm_adaptive_function else: sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 8b808ce0b..72ff7957f 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -283,6 +283,33 @@ class SamplerLMS: sampler = comfy.samplers.ksampler("lms", {"order": order}) return (sampler, ) +class SamplerDPMAdaptative: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"order": ("INT", {"default": 3, "min": 2, "max": 3}), + "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 
100.0, "step":0.01, "round": False}), + "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, + "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, + "s_noise":s_noise }) + return (sampler, ) + class SamplerCustom: @classmethod def INPUT_TYPES(s): @@ -348,6 +375,7 @@ NODE_CLASS_MAPPINGS = { "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, + "SamplerDPMAdaptative": SamplerDPMAdaptative, "SplitSigmas": SplitSigmas, "FlipSigmas": FlipSigmas, } From d7897fff2cfd38eac051fbc958a6f944bdf68cc9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 16 Mar 2024 14:49:35 -0400 Subject: [PATCH 06/19] Move cascade scale factor from stage_a to latent_formats.py --- comfy/latent_formats.py | 2 +- comfy/ldm/cascade/stage_a.py | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 674364e72..4ca466d9a 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -95,7 +95,7 @@ class SC_Prior(LatentFormat): class SC_B(LatentFormat): def __init__(self): - self.scale_factor = 1.0 + self.scale_factor = 1.0 / 0.43 self.latent_rgb_factors = [ [ 0.1121, 0.2006, 0.1023], [-0.2093, -0.0222, -0.0195], diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index 260ccfc0b..ca8867eaf 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -163,11 +163,9 @@ class ResBlock(nn.Module): class StageA(nn.Module): - def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, - scale_factor=0.43): # 0.3764 + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192): super().__init__() self.c_latent = c_latent - self.scale_factor = scale_factor c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] # Encoder blocks @@ -214,12 +212,11 @@ class StageA(nn.Module): x = self.down_blocks(x) if quantize: qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) - return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + return qe, x, indices, vq_loss + commit_loss * 0.25 else: - return x / self.scale_factor + return x def decode(self, x): - x = x * self.scale_factor x = self.up_blocks(x) x = self.out_block(x) return x From d3406d8d588d6b0c2da44c9bf378733a2077a14f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 17 Mar 2024 08:57:49 -0400 Subject: [PATCH 07/19] Increase image batch nodes maximum values. 
--- comfy_extras/nodes_images.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 8f638bf8f..af37666b2 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -37,7 +37,7 @@ class RepeatImageBatch: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",), - "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + "amount": ("INT", {"default": 1, "min": 1, "max": 4096}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "repeat" @@ -52,8 +52,8 @@ class ImageFromBatch: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",), - "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), - "length": ("INT", {"default": 1, "min": 1, "max": 64}), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), + "length": ("INT", {"default": 1, "min": 1, "max": 4096}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "frombatch" From cacb022c4a5b9614f96086a866c8a4c4e9e85760 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Mar 2024 00:26:23 -0400 Subject: [PATCH 08/19] Make saved SD1 checkpoints match more closely the official one. --- comfy/supported_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 375821032..e12935a27 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -45,6 +45,11 @@ class SD15(supported_models_base.BASE): return state_dict def process_clip_state_dict_for_saving(self, state_dict): + pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"] + for p in pop_keys: + if p in state_dict: + state_dict.pop(p) + replace_prefix = {"clip_l.": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) From b1a16d4500b89f4c7db2aadf41acb378efd20948 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Mar 2024 13:50:11 -0400 Subject: [PATCH 09/19] Fix stable cascade img2img not working with all resolutions. --- comfy_extras/nodes_stable_cascade.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index 7e2d37d22..fcbbeb27f 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -74,7 +74,7 @@ class StableCascade_StageC_VAEEncode: s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1) c_latent = vae.encode(s[:,:,:,:3]) - b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4]) + b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2]) return ({ "samples": c_latent, }, { From 0b78213bdaaa9021daa870c544be6ce86f54d30d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Mar 2024 15:51:23 -0400 Subject: [PATCH 10/19] Fix neg scale step. 
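Without an explicit "step" in the input spec the frontend widget falls back
to its default increment, which is too coarse for neg_scale, a value that
is normally tuned in small amounts. Declaring the step fixes the widget:

    "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),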
--- comfy_extras/nodes_perpneg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index 64bbc1dcd..dc73c5528 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -10,7 +10,7 @@ class PerpNeg: def INPUT_TYPES(s): return {"required": {"model": ("MODEL", ), "empty_conditioning": ("CONDITIONING", ), - "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" From 40e124c6be01195eada95e8c319ca6ddf4fd1a17 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Mar 2024 10:04:51 -0400 Subject: [PATCH 11/19] SV3D support. --- comfy/model_base.py | 30 +++++++++++++++++ comfy/supported_models.py | 37 ++++++++++++++++++++- comfy_extras/nodes_stable3d.py | 53 +++++++++++++++++++++++++++---- comfy_extras/nodes_video_model.py | 28 ++++++++++++++++ 4 files changed, 141 insertions(+), 7 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 5da71e632..bc019de53 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -380,6 +380,36 @@ class SVD_img2vid(BaseModel): out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) return out +class SV3D_u(SVD_img2vid): + def encode_adm(self, **kwargs): + augmentation = kwargs.get("augmentation_level", 0) + + out = [] + out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) + + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) + return flat + +class SV3D_p(SVD_img2vid): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): + super().__init__(model_config, model_type, device=device) + self.embedder_512 = Timestep(512) + + def encode_adm(self, **kwargs): + augmentation = kwargs.get("augmentation_level", 0) + elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here + azimuth = kwargs.get("azimuth", 0) + noise = kwargs.get("noise", None) + + out = [] + out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) + out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0)))) + out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0)))) + + out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out)) + return torch.cat(out, dim=1) + + class Stable_Zero123(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): super().__init__(model_config, model_type, device=device) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index e12935a27..2ce9736b7 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -284,6 +284,41 @@ class SVD_img2vid(supported_models_base.BASE): def clip_target(self): return None +class SV3D_u(SVD_img2vid): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 256, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + vae_key_prefix = ["conditioner.embedders.1.encoder."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SV3D_u(self, device=device) + return out + +class SV3D_p(SV3D_u): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": 
True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 1280, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SV3D_p(self, device=device) + return out + class Stable_Zero123(supported_models_base.BASE): unet_config = { "context_dim": 768, @@ -405,5 +440,5 @@ class Stable_Cascade_B(Stable_Cascade_C): return out -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B] +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index 4375d8f96..be2e34c28 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -29,8 +29,8 @@ class StableZero123_Conditioning: "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") @@ -62,10 +62,10 @@ class StableZero123_Conditioning_Batched: "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") @@ -95,8 +95,49 @@ class StableZero123_Conditioning_Batched: latent = torch.zeros([batch_size, 4, height // 8, width // 8]) return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) +class SV3D_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "video_frames": ("INT", {"default": 21, "min": 1, 
"max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + azimuth = 0 + azimuth_increment = 360 / (max(video_frames, 2) - 1) + + elevations = [] + azimuths = [] + for i in range(video_frames): + elevations.append(elevation) + azimuths.append(azimuth) + azimuth += azimuth_increment + + positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]] + latent = torch.zeros([video_frames, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + NODE_CLASS_MAPPINGS = { "StableZero123_Conditioning": StableZero123_Conditioning, "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, + "SV3D_Conditioning": SV3D_Conditioning, } diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index a52625652..1a0189ed4 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -79,6 +79,33 @@ class VideoLinearCFGGuidance: m.set_model_sampler_cfg_function(linear_cfg) return (m, ) +class VideoTriangleCFGGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "sampling/video_models" + + def patch(self, model, min_cfg): + def linear_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + period = 1.0 + values = torch.linspace(0, 1, cond.shape[0], device=cond.device) + values = 2 * (values / period - torch.floor(values / period + 0.5)).abs() + scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1)) + + return uncond + scale * (cond - uncond) + + m = model.clone() + m.set_model_sampler_cfg_function(linear_cfg) + return (m, ) + class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave): CATEGORY = "_for_testing" @@ -98,6 +125,7 @@ NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, + "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance, "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, } From 0c55f16c9e66eaa4915e288b34e4f848fb2d949f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 19 Mar 2024 02:34:01 -0400 Subject: [PATCH 12/19] Remove code that should be useless now. 
--- nodes.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/nodes.py b/nodes.py index 453f6e606..d24df200c 100644 --- a/nodes.py +++ b/nodes.py @@ -15,9 +15,6 @@ from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch -sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) - - import comfy.diffusers_load import comfy.samplers import comfy.sample From d14bdb18967f7413852a364747c49599de537eec Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 19 Mar 2024 11:17:37 -0400 Subject: [PATCH 13/19] Revert, NOTE: this will be removed again soon please fix your nodes. --- nodes.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nodes.py b/nodes.py index d24df200c..453f6e606 100644 --- a/nodes.py +++ b/nodes.py @@ -15,6 +15,9 @@ from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch +sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) + + import comfy.diffusers_load import comfy.samplers import comfy.sample From 150a3e946fe49990454b0b81e26f06cdd9fbab8f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Mar 2024 01:35:59 -0400 Subject: [PATCH 14/19] Make LCM sampler use the model noise scaling function. --- comfy/k_diffusion/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 761c2e0ef..57518c7b6 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -748,7 +748,7 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n x = denoised if sigmas[i + 1] > 0: - x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) + x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1] * s_in, noise_sampler(sigmas[i], sigmas[i + 1]), x) return x From c18a203a8abdd0fce24743a838fe0d0400d8ff09 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Mar 2024 01:29:26 -0400 Subject: [PATCH 15/19] Don't unload model weights for non weight patches. 
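ModelPatcher clones now share their backup dicts and carry a patches_uuid
that is refreshed whenever add_patches() is called. This lets the loader
tell weight patches apart from object patches and model_options tweaks:
two clones with the same uuid (or with no patches at all) have identical
weights, so swapping between them no longer forces an unload and re-patch.
A sketch of the behaviour (hypothetical usage):

    m2 = m1.clone()
    m2.set_model_sampler_cfg_function(my_cfg)  # non-weight patch only
    m1.clone_has_same_weights(m2)  # True: weights can stay loaded

    m2.add_patches(lora_patches, 0.5)          # assigns a fresh patches_uuid
    m1.clone_has_same_weights(m2)  # falsy: weights have to be re-patched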
--- comfy/model_management.py | 44 ++++++++++++++++++++++------ comfy/model_patcher.py | 60 ++++++++++++++++++++++++++------------- 2 files changed, 76 insertions(+), 28 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 66fa918bb..74958908a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -273,6 +273,7 @@ class LoadedModel: def __init__(self, model): self.model = model self.device = model.load_device + self.weights_loaded = False def model_memory(self): return self.model.model_size() @@ -289,11 +290,13 @@ class LoadedModel: self.model.model_patches_to(self.device) self.model.model_patches_to(self.model.model_dtype()) + load_weights = not self.weights_loaded + try: - if lowvram_model_memory > 0: + if lowvram_model_memory > 0 and load_weights: self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory) else: - self.real_model = self.model.patch_model(device_to=patch_model_to) + self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights) except Exception as e: self.model.unpatch_model(self.model.offload_device) self.model_unload() @@ -302,11 +305,13 @@ class LoadedModel: if is_intel_xpu() and not args.disable_ipex_optimize: self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) + self.weights_loaded = True return self.real_model - def model_unload(self): - self.model.unpatch_model(self.model.offload_device) + def model_unload(self, unpatch_weights=True): + self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.model_patches_to(self.model.offload_device) + self.weights_loaded = self.weights_loaded and not unpatch_weights def __eq__(self, other): return self.model is other.model @@ -314,15 +319,35 @@ class LoadedModel: def minimum_inference_memory(): return (1024 * 1024 * 1024) -def unload_model_clones(model): +def unload_model_clones(loaded_model, unload_weights_only=True): + model = loaded_model.model + to_unload = [] for i in range(len(current_loaded_models)): if model.is_clone(current_loaded_models[i].model): to_unload = [i] + to_unload + if len(to_unload) == 0: + return + + same_weights = 0 for i in to_unload: - logging.debug("unload clone {}".format(i)) - current_loaded_models.pop(i).model_unload() + if model.clone_has_same_weights(current_loaded_models[i].model): + same_weights += 1 + + if same_weights == len(to_unload): + unload_weight = False + else: + unload_weight = True + + if unload_weights_only and unload_weight == False: + return + + for i in to_unload: + logging.debug("unload clone {} {}".format(i, unload_weight)) + current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight) + + loaded_model.weights_loaded = not unload_weight def free_memory(memory_required, device, keep_loaded=[]): unloaded_model = False @@ -377,13 +402,16 @@ def load_models_gpu(models, memory_required=0): total_memory_required = {} for loaded_model in models_to_load: - unload_model_clones(loaded_model.model) + unload_model_clones(loaded_model, unload_weights_only=True) #unload clones where the weights are different total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) for device in total_memory_required: if device != torch.device("cpu"): free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) + for 
loaded_model in models_to_load: + unload_model_clones(loaded_model, unload_weights_only=False) #unload the rest of the clones where the weights can stay loaded + for loaded_model in models_to_load: model = loaded_model.model torch_dev = model.load_device diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 475fa812c..aa78302d2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -2,6 +2,7 @@ import torch import copy import inspect import logging +import uuid import comfy.utils import comfy.model_management @@ -25,6 +26,7 @@ class ModelPatcher: self.weight_inplace_update = weight_inplace_update self.model_lowvram = False + self.patches_uuid = uuid.uuid4() def model_size(self): if self.size > 0: @@ -39,10 +41,13 @@ class ModelPatcher: n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] + n.patches_uuid = self.patches_uuid n.object_patches = self.object_patches.copy() n.model_options = copy.deepcopy(self.model_options) n.model_keys = self.model_keys + n.backup = self.backup + n.object_patches_backup = self.object_patches_backup return n def is_clone(self, other): @@ -50,6 +55,19 @@ class ModelPatcher: return True return False + def clone_has_same_weights(self, clone): + if not self.is_clone(clone): + return False + + if len(self.patches) == 0 and len(clone.patches) == 0: + return True + + if self.patches_uuid == clone.patches_uuid: + if len(self.patches) != len(clone.patches): + logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.") + else: + return True + def memory_required(self, input_shape): return self.model.memory_required(input_shape=input_shape) @@ -154,6 +172,7 @@ class ModelPatcher: current_patches.append((strength_patch, patches[k], strength_model)) self.patches[k] = current_patches + self.patches_uuid = uuid.uuid4() return list(p) def get_key_patches(self, filter_prefix=None): @@ -387,31 +406,32 @@ class ModelPatcher: return weight - def unpatch_model(self, device_to=None): - if self.model_lowvram: - for m in self.model.modules(): - if hasattr(m, "prev_comfy_cast_weights"): - m.comfy_cast_weights = m.prev_comfy_cast_weights - del m.prev_comfy_cast_weights - m.weight_function = None - m.bias_function = None + def unpatch_model(self, device_to=None, unpatch_weights=True): + if unpatch_weights: + if self.model_lowvram: + for m in self.model.modules(): + if hasattr(m, "prev_comfy_cast_weights"): + m.comfy_cast_weights = m.prev_comfy_cast_weights + del m.prev_comfy_cast_weights + m.weight_function = None + m.bias_function = None - self.model_lowvram = False + self.model_lowvram = False - keys = list(self.backup.keys()) + keys = list(self.backup.keys()) - if self.weight_inplace_update: - for k in keys: - comfy.utils.copy_to_param(self.model, k, self.backup[k]) - else: - for k in keys: - comfy.utils.set_attr_param(self.model, k, self.backup[k]) + if self.weight_inplace_update: + for k in keys: + comfy.utils.copy_to_param(self.model, k, self.backup[k]) + else: + for k in keys: + comfy.utils.set_attr_param(self.model, k, self.backup[k]) - self.backup = {} + self.backup.clear() - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to keys = list(self.object_patches_backup.keys()) for k in keys: From 4b9005e949224782236a8b914eae48bc503f1f18 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Mar 2024 13:53:45 -0400 Subject: [PATCH 16/19] Fix regression with model merging. 
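unload_model_clones() takes the raw ModelPatcher again and gains a
force_unload flag, so direct callers outside of load_models_gpu() keep the
old unconditional behaviour while the loader opts out and applies the
returned flag itself:

    # loader path: only unload what actually has to go
    weights_unloaded = unload_model_clones(loaded_model.model,
                                           unload_weights_only=False,
                                           force_unload=False)
    if weights_unloaded is not None:
        loaded_model.weights_loaded = not weights_unloaded

    # direct callers (force_unload defaults to True):
    unload_model_clones(model)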
--- comfy/model_management.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 74958908a..11c97f290 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -319,16 +319,14 @@ class LoadedModel: def minimum_inference_memory(): return (1024 * 1024 * 1024) -def unload_model_clones(loaded_model, unload_weights_only=True): - model = loaded_model.model - +def unload_model_clones(model, unload_weights_only=True, force_unload=True): to_unload = [] for i in range(len(current_loaded_models)): if model.is_clone(current_loaded_models[i].model): to_unload = [i] + to_unload if len(to_unload) == 0: - return + return None same_weights = 0 for i in to_unload: @@ -340,14 +338,15 @@ def unload_model_clones(loaded_model, unload_weights_only=True): else: unload_weight = True - if unload_weights_only and unload_weight == False: - return + if not force_unload: + if unload_weights_only and unload_weight == False: + return None for i in to_unload: logging.debug("unload clone {} {}".format(i, unload_weight)) current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight) - loaded_model.weights_loaded = not unload_weight + return unload_weight def free_memory(memory_required, device, keep_loaded=[]): unloaded_model = False @@ -402,7 +401,7 @@ def load_models_gpu(models, memory_required=0): total_memory_required = {} for loaded_model in models_to_load: - unload_model_clones(loaded_model, unload_weights_only=True) #unload clones where the weights are different + unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) for device in total_memory_required: @@ -410,7 +409,9 @@ def load_models_gpu(models, memory_required=0): free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) for loaded_model in models_to_load: - unload_model_clones(loaded_model, unload_weights_only=False) #unload the rest of the clones where the weights can stay loaded + weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded + if weights_unloaded is not None: + loaded_model.weights_loaded = not weights_unloaded for loaded_model in models_to_load: model = loaded_model.model From 5d875d77fe6e31a4b0bc6dc36f0441eba3b6afe1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Mar 2024 20:48:54 -0400 Subject: [PATCH 17/19] Fix regression with lcm not working with batches. 
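sigmas[i + 1] is a 0-d tensor that already broadcasts over the whole
batch; multiplying it by s_in (a ones tensor of shape [batch]) handed
noise_scaling() a 1-d sigma that cannot broadcast against a
[batch, C, H, W] latent, so sample_lcm failed for batch sizes > 1.
A minimal illustration of the broadcasting problem (hypothetical shapes):

    import torch

    noise = torch.randn(4, 4, 64, 64)  # [batch, C, H, W]
    noise * torch.tensor(0.5)          # fine: 0-d sigma broadcasts everywhere
    noise * torch.full((4,), 0.5)      # RuntimeError: [4] aligns with W, not batch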
--- comfy/k_diffusion/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 57518c7b6..7af016829 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -748,7 +748,7 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n x = denoised if sigmas[i + 1] > 0: - x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1] * s_in, noise_sampler(sigmas[i], sigmas[i + 1]), x) + x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x) return x From 062483823738ed610d8d074ba63910c90e9d45b7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 21 Mar 2024 14:49:11 -0400 Subject: [PATCH 18/19] Add inverse noise scaling function. --- comfy/model_sampling.py | 3 +++ comfy/samplers.py | 1 + 2 files changed, 4 insertions(+) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index d325f76d9..37976b326 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -20,6 +20,9 @@ class EPS: noise += latent_image return noise + def inverse_noise_scaling(self, sigma, latent): + return latent + class V_PREDICTION(EPS): def calculate_denoised(self, sigma, model_output, model_input): sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) diff --git a/comfy/samplers.py b/comfy/samplers.py index d721cb2e5..3678dc818 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -546,6 +546,7 @@ class KSAMPLER(Sampler): k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) + samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) return samples From a28a9dc83684624ee2167c0b92d976bb68f2c606 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 22 Mar 2024 12:56:48 -0400 Subject: [PATCH 19/19] Add an example to use the SaveImageWebsocket node and enable it. --- ...ve.py.disabled => websocket_image_save.py} | 6 +- .../websockets_api_example_ws_images.py | 159 ++++++++++++++++++ 2 files changed, 160 insertions(+), 5 deletions(-) rename custom_nodes/{websocket_image_save.py.disabled => websocket_image_save.py} (84%) create mode 100644 script_examples/websockets_api_example_ws_images.py diff --git a/custom_nodes/websocket_image_save.py.disabled b/custom_nodes/websocket_image_save.py similarity index 84% rename from custom_nodes/websocket_image_save.py.disabled rename to custom_nodes/websocket_image_save.py index b85a5de8b..5aa573642 100644 --- a/custom_nodes/websocket_image_save.py.disabled +++ b/custom_nodes/websocket_image_save.py @@ -10,10 +10,6 @@ import time #binary images on the websocket with a 8 byte header indicating the type #of binary message (first 4 bytes) and the image format (next 4 bytes). -#The reason this node is disabled by default is because there is a small -#issue when using it with the default ComfyUI web interface: When generating -#batches only the last image will be shown in the UI. - #Note that no metadata will be put in the images saved with this node. 
class SaveImageWebsocket: @@ -28,7 +24,7 @@ class SaveImageWebsocket: OUTPUT_NODE = True - CATEGORY = "image" + CATEGORY = "api/image" def save_images(self, images): pbar = comfy.utils.ProgressBar(images.shape[0]) diff --git a/script_examples/websockets_api_example_ws_images.py b/script_examples/websockets_api_example_ws_images.py new file mode 100644 index 000000000..737488621 --- /dev/null +++ b/script_examples/websockets_api_example_ws_images.py @@ -0,0 +1,159 @@ +#This is an example that uses the websockets api and the SaveImageWebsocket node to get images directly without +#them being saved to disk + +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import json +import urllib.request +import urllib.parse + +server_address = "127.0.0.1:8188" +client_id = str(uuid.uuid4()) + +def queue_prompt(prompt): + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + +def get_image(filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response: + return response.read() + +def get_history(prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response: + return json.loads(response.read()) + +def get_images(ws, prompt): + prompt_id = queue_prompt(prompt)['prompt_id'] + output_images = {} + current_node = "" + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['prompt_id'] == prompt_id: + if data['node'] is None: + break #Execution is done + else: + current_node = data['node'] + else: + if current_node == 'save_image_websocket_node': + images_output = output_images.get(current_node, []) + images_output.append(out[8:]) + output_images[current_node] = images_output + + return output_images + +prompt_text = """ +{ + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": 8, + "denoise": 1, + "latent_image": [ + "5", + 0 + ], + "model": [ + "4", + 0 + ], + "negative": [ + "7", + 0 + ], + "positive": [ + "6", + 0 + ], + "sampler_name": "euler", + "scheduler": "normal", + "seed": 8566257, + "steps": 20 + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "v1-5-pruned-emaonly.ckpt" + } + }, + "5": { + "class_type": "EmptyLatentImage", + "inputs": { + "batch_size": 1, + "height": 512, + "width": 512 + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "masterpiece best quality girl" + } + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "bad hands" + } + }, + "8": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + } + }, + "save_image_websocket_node": { + "class_type": "SaveImageWebsocket", + "inputs": { + "images": [ + "8", + 0 + ] + } + } +} +""" + +prompt = json.loads(prompt_text) +#set the text prompt for our positive CLIPTextEncode +prompt["6"]["inputs"]["text"] = "masterpiece best quality man" + +#set the seed for our KSampler node +prompt["3"]["inputs"]["seed"] = 5 + +ws = websocket.WebSocket() 
+ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) +images = get_images(ws, prompt) + +#Commented out code to display the output images: + +# for node_id in images: +# for image_data in images[node_id]: +# from PIL import Image +# import io +# image = Image.open(io.BytesIO(image_data)) +# image.show() +
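Usage note on the example above: each entry in images[node_id] is the raw
encoded image with the 8-byte binary header already stripped (out[8:]), so
the bytes can be handed to PIL directly. A sketch for saving instead of
displaying, in the same commented style as the snippet above:

# from PIL import Image
# import io
# for node_id in images:
#     for i, image_data in enumerate(images[node_id]):
#         Image.open(io.BytesIO(image_data)).save("{}_{}.png".format(node_id, i))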