diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index f144738cd..859b27e3d 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -8,7 +8,7 @@ import shutil import threading import time -from comfy.utils import hijack_progress +from ..utils import hijack_progress from .extra_model_paths import load_extra_path_config from .main_pre import args from .. import model_management diff --git a/comfy/controlnet.py b/comfy/controlnet.py index d6a6d227c..a6a5fadf6 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -202,7 +202,7 @@ class ControlNet(ControlBase): super().cleanup() class ControlLoraOps: - class Linear(torch.nn.Module): + class Linear(torch.nn.Module, ops.CastWeightBiasOp): def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} @@ -221,7 +221,7 @@ class ControlLoraOps: else: return torch.nn.functional.linear(input, weight, bias) - class Conv2d(torch.nn.Module): + class Conv2d(torch.nn.Module, ops.CastWeightBiasOp): def __init__( self, in_channels, diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 761c2e0ef..7af016829 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -748,7 +748,7 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n x = denoised if sigmas[i + 1] > 0: - x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) + x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x) return x diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 674364e72..4ca466d9a 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -95,7 +95,7 @@ class SC_Prior(LatentFormat): class SC_B(LatentFormat): def __init__(self): - self.scale_factor = 1.0 + self.scale_factor = 1.0 / 0.43 self.latent_rgb_factors = [ [ 0.1121, 0.2006, 0.1023], [-0.2093, -0.0222, -0.0195], diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index 260ccfc0b..ca8867eaf 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -163,11 +163,9 @@ class ResBlock(nn.Module): class StageA(nn.Module): - def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, - scale_factor=0.43): # 0.3764 + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192): super().__init__() self.c_latent = c_latent - self.scale_factor = scale_factor c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] # Encoder blocks @@ -214,12 +212,11 @@ class StageA(nn.Module): x = self.down_blocks(x) if quantize: qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) - return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + return qe, x, indices, vq_loss + commit_loss * 0.25 else: - return x / self.scale_factor + return x def decode(self, x): - x = x * self.scale_factor x = self.up_blocks(x) x = self.out_block(x) return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 95a6829e5..d0e8193a2 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -379,6 +379,36 @@ class SVD_img2vid(BaseModel): out['num_video_frames'] = conds.CONDConstant(noise.shape[0]) return out +class SV3D_u(SVD_img2vid): + def encode_adm(self, **kwargs): + augmentation = kwargs.get("augmentation_level", 0) + + out = [] + out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) + + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) + return flat + +class SV3D_p(SVD_img2vid): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): + super().__init__(model_config, model_type, device=device) + self.embedder_512 = Timestep(512) + + def encode_adm(self, **kwargs): + augmentation = kwargs.get("augmentation_level", 0) + elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here + azimuth = kwargs.get("azimuth", 0) + noise = kwargs.get("noise", None) + + out = [] + out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) + out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0)))) + out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0)))) + + out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out)) + return torch.cat(out, dim=1) + + class Stable_Zero123(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): super().__init__(model_config, model_type, device=device) diff --git a/comfy/model_management.py b/comfy/model_management.py index e4891afdf..8954973cb 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -276,8 +276,8 @@ def module_size(module): class LoadedModel: def __init__(self, model): self.model = model - self.model_accelerated = False self.device = model.load_device + self.weights_loaded = False def model_memory(self): return self.model.model_size() @@ -289,54 +289,33 @@ class LoadedModel: return self.model_memory() def model_load(self, lowvram_model_memory=0): - patch_model_to = None - if lowvram_model_memory == 0: - patch_model_to = self.device + patch_model_to = self.device self.model.model_patches_to(self.device) self.model.model_patches_to(self.model.model_dtype()) + load_weights = not self.weights_loaded + try: - self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU + if lowvram_model_memory > 0 and load_weights: + self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory) + else: + self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights) except Exception as e: self.model.unpatch_model(self.model.offload_device) self.model_unload() raise e - if lowvram_model_memory > 0: - logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024))) - mem_counter = 0 - for m in self.real_model.modules(): - if hasattr(m, "comfy_cast_weights"): - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - module_mem = module_size(m) - if mem_counter + module_mem < lowvram_model_memory: - m.to(self.device) - mem_counter += module_mem - elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode - m.to(self.device) - mem_counter += module_size(m) - logging.warning("lowvram: loaded module regularly {}".format(m)) - - self.model_accelerated = True - if is_intel_xpu() and not args.disable_ipex_optimize: self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) + self.weights_loaded = True return self.real_model - def model_unload(self): - if self.model_accelerated: - for m in self.real_model.modules(): - if hasattr(m, "prev_comfy_cast_weights"): - m.comfy_cast_weights = m.prev_comfy_cast_weights - del m.prev_comfy_cast_weights - - self.model_accelerated = False - - self.model.unpatch_model(self.model.offload_device) + def model_unload(self, unpatch_weights=True): + self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.model_patches_to(self.model.offload_device) + self.weights_loaded = self.weights_loaded and not unpatch_weights def __eq__(self, other): return self.model is other.model @@ -344,16 +323,35 @@ class LoadedModel: def minimum_inference_memory(): return (1024 * 1024 * 1024) -def unload_model_clones(model): +def unload_model_clones(model, unload_weights_only=True, force_unload=True): with model_management_lock: to_unload = [] for i in range(len(current_loaded_models)): if model.is_clone(current_loaded_models[i].model): to_unload = [i] + to_unload + if len(to_unload) == 0: + return None + + same_weights = 0 for i in to_unload: - logging.debug("unload clone {}".format(i)) - current_loaded_models.pop(i).model_unload() + if model.clone_has_same_weights(current_loaded_models[i].model): + same_weights += 1 + + if same_weights == len(to_unload): + unload_weight = False + else: + unload_weight = True + + if not force_unload: + if unload_weights_only and unload_weight == False: + return None + + for i in to_unload: + logging.debug("unload clone {}{}".format(i, unload_weight)) + current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight) + + return unload_weight def free_memory(memory_required, device, keep_loaded=[]): with model_management_lock: @@ -410,13 +408,18 @@ def load_models_gpu(models, memory_required=0): total_memory_required = {} for loaded_model in models_to_load: - unload_model_clones(loaded_model.model) + unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) for device in total_memory_required: if device != torch.device("cpu"): free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) + for loaded_model in models_to_load: + weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded + if weights_unloaded is not None: + loaded_model.weights_loaded = not weights_unloaded + for loaded_model in models_to_load: model = loaded_model.model torch_dev = model.load_device diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3ca0d64a8..6d679aa85 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -2,6 +2,7 @@ import torch import copy import inspect import logging +import uuid from . import utils from . import model_management @@ -24,6 +25,8 @@ class ModelPatcher: self.current_device = current_device self.weight_inplace_update = weight_inplace_update + self.model_lowvram = False + self.patches_uuid = uuid.uuid4() def model_size(self): if self.size > 0: @@ -38,10 +41,13 @@ class ModelPatcher: n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] + n.patches_uuid = self.patches_uuid n.object_patches = self.object_patches.copy() n.model_options = copy.deepcopy(self.model_options) n.model_keys = self.model_keys + n.backup = self.backup + n.object_patches_backup = self.object_patches_backup return n def is_clone(self, other): @@ -49,6 +55,19 @@ class ModelPatcher: return True return False + def clone_has_same_weights(self, clone): + if not self.is_clone(clone): + return False + + if len(self.patches) == 0 and len(clone.patches) == 0: + return True + + if self.patches_uuid == clone.patches_uuid: + if len(self.patches) != len(clone.patches): + logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.") + else: + return True + def memory_required(self, input_shape): return self.model.memory_required(input_shape=input_shape) @@ -153,6 +172,7 @@ class ModelPatcher: current_patches.append((strength_patch, patches[k], strength_model)) self.patches[k] = current_patches + self.patches_uuid = uuid.uuid4() return list(p) def get_key_patches(self, filter_prefix=None): @@ -178,6 +198,27 @@ class ModelPatcher: sd.pop(k) return sd + def patch_weight_to_device(self, key, device_to=None): + if key not in self.patches: + return + + weight = utils.get_attr(self.model, key) + + inplace_update = self.weight_inplace_update + + if key not in self.backup: + self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + + if device_to is not None: + temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + else: + temp_weight = weight.to(torch.float32, copy=True) + out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) + if inplace_update: + utils.copy_to_param(self.model, key, out_weight) + else: + utils.set_attr_param(self.model, key, out_weight) + def patch_model(self, device_to=None, patch_weights=True): for k in self.object_patches: old = utils.set_attr(self.model, k, self.object_patches[k]) @@ -191,23 +232,7 @@ class ModelPatcher: logging.warning("could not patch. key doesn't exist in model: {}".format(key)) continue - weight = model_sd[key] - - inplace_update = self.weight_inplace_update - - if key not in self.backup: - self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) - - if device_to is not None: - temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True) - else: - temp_weight = weight.to(torch.float32, copy=True) - out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - if inplace_update: - utils.copy_to_param(self.model, key, out_weight) - else: - utils.set_attr_param(self.model, key, out_weight) - del temp_weight + self.patch_weight_to_device(key, device_to) if device_to is not None: self.model.to(device_to) @@ -215,6 +240,47 @@ class ModelPatcher: return self.model + def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0): + self.patch_model(device_to, patch_weights=False) + + logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024))) + class LowVramPatch: + def __init__(self, key, model_patcher): + self.key = key + self.model_patcher = model_patcher + def __call__(self, weight): + return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) + + mem_counter = 0 + for n, m in self.model.named_modules(): + lowvram_weight = False + if hasattr(m, "comfy_cast_weights"): + module_mem = model_management.module_size(m) + if mem_counter + module_mem >= lowvram_model_memory: + lowvram_weight = True + + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + + if lowvram_weight: + if weight_key in self.patches: + m.weight_function = LowVramPatch(weight_key, self) + if bias_key in self.patches: + m.bias_function = LowVramPatch(weight_key, self) + + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + else: + if hasattr(m, "weight"): + self.patch_weight_to_device(weight_key, device_to) + self.patch_weight_to_device(bias_key, device_to) + m.to(device_to) + mem_counter += model_management.module_size(m) + logging.debug("lowvram: loaded module regularly {}".format(m)) + + self.model_lowvram = True + return self.model + def calculate_weight(self, patches, weight, key): for p in patches: alpha = p[0] @@ -340,21 +406,32 @@ class ModelPatcher: return weight - def unpatch_model(self, device_to=None): - keys = list(self.backup.keys()) + def unpatch_model(self, device_to=None, unpatch_weights=True): + if unpatch_weights: + if self.model_lowvram: + for m in self.model.modules(): + if hasattr(m, "prev_comfy_cast_weights"): + m.comfy_cast_weights = m.prev_comfy_cast_weights + del m.prev_comfy_cast_weights + m.weight_function = None + m.bias_function = None - if self.weight_inplace_update: - for k in keys: - utils.copy_to_param(self.model, k, self.backup[k]) - else: - for k in keys: - utils.set_attr_param(self.model, k, self.backup[k]) + self.model_lowvram = False - self.backup = {} + keys = list(self.backup.keys()) - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to + if self.weight_inplace_update: + for k in keys: + utils.copy_to_param(self.model, k, self.backup[k]) + else: + for k in keys: + utils.set_attr_param(self.model, k, self.backup[k]) + + self.backup.clear() + + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to keys = list(self.object_patches_backup.keys()) for k in keys: diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 28442ade1..2c4f1f4c0 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -20,6 +20,9 @@ class EPS: noise += latent_image return noise + def inverse_noise_scaling(self, sigma, latent): + return latent + class V_PREDICTION(EPS): def calculate_denoised(self, sigma, model_output, model_input): sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) diff --git a/comfy/ops.py b/comfy/ops.py index b4760cf1b..a278ff755 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -24,13 +24,20 @@ def cast_bias_weight(s, input): non_blocking = model_management.device_supports_non_blocking(input.device) if s.bias is not None: bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + if s.bias_function is not None: + bias = s.bias_function(bias) weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + if s.weight_function is not None: + weight = s.weight_function(weight) return weight, bias +class CastWeightBiasOp: + comfy_cast_weights = False + weight_function = None + bias_function = None class disable_weight_init: - class Linear(torch.nn.Linear): - comfy_cast_weights = False + class Linear(torch.nn.Linear, CastWeightBiasOp): def reset_parameters(self): return None @@ -44,8 +51,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class Conv2d(torch.nn.Conv2d): - comfy_cast_weights = False + class Conv2d(torch.nn.Conv2d, CastWeightBiasOp): def reset_parameters(self): return None @@ -59,8 +65,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class Conv3d(torch.nn.Conv3d): - comfy_cast_weights = False + class Conv3d(torch.nn.Conv3d, CastWeightBiasOp): def reset_parameters(self): return None @@ -74,8 +79,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class GroupNorm(torch.nn.GroupNorm): - comfy_cast_weights = False + class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp): def reset_parameters(self): return None @@ -90,8 +94,7 @@ class disable_weight_init: return super().forward(*args, **kwargs) - class LayerNorm(torch.nn.LayerNorm): - comfy_cast_weights = False + class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp): def reset_parameters(self): return None @@ -109,8 +112,7 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class ConvTranspose2d(torch.nn.ConvTranspose2d): - comfy_cast_weights = False + class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): def reset_parameters(self): return None diff --git a/comfy/samplers.py b/comfy/samplers.py index 3539bd3a7..e5a42a8b3 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -549,6 +549,7 @@ class KSAMPLER(Sampler): k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) + samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) return samples @@ -562,11 +563,11 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable) sampler_function = dpm_fast_function elif sampler_name == "dpm_adaptive": - def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable): + def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options): sigma_min = sigmas[-1] if sigma_min == 0: sigma_min = sigmas[-2] - return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable) + return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options) sampler_function = dpm_adaptive_function else: sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 375821032..2ce9736b7 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -45,6 +45,11 @@ class SD15(supported_models_base.BASE): return state_dict def process_clip_state_dict_for_saving(self, state_dict): + pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"] + for p in pop_keys: + if p in state_dict: + state_dict.pop(p) + replace_prefix = {"clip_l.": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) @@ -279,6 +284,41 @@ class SVD_img2vid(supported_models_base.BASE): def clip_target(self): return None +class SV3D_u(SVD_img2vid): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 256, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + vae_key_prefix = ["conditioner.embedders.1.encoder."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SV3D_u(self, device=device) + return out + +class SV3D_p(SV3D_u): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 1280, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SV3D_p(self, device=device) + return out + class Stable_Zero123(supported_models_base.BASE): unet_config = { "context_dim": 768, @@ -400,5 +440,5 @@ class Stable_Cascade_B(Stable_Cascade_C): return out -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B] +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p] models += [SVD_img2vid] diff --git a/comfy_extras/nodes/nodes_custom_sampler.py b/comfy_extras/nodes/nodes_custom_sampler.py index b08243675..5a99a6e9e 100644 --- a/comfy_extras/nodes/nodes_custom_sampler.py +++ b/comfy_extras/nodes/nodes_custom_sampler.py @@ -183,6 +183,28 @@ class KSamplerSelect: sampler = samplers.sampler_object(sampler_name) return (sampler, ) +class SamplerDPMPP_3M_SDE: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "noise_device": (['gpu', 'cpu'], ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_3m_sde" + else: + sampler_name = "dpmpp_3m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) + return (sampler, ) + class SamplerDPMPP_2M_SDE: @classmethod def INPUT_TYPES(s): @@ -247,6 +269,49 @@ class SamplerEulerAncestral: sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) return (sampler, ) +class SamplerLMS: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"order": ("INT", {"default": 4, "min": 1, "max": 100}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, order): + sampler = comfy.samplers.ksampler("lms", {"order": order}) + return (sampler, ) + +class SamplerDPMAdaptative: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"order": ("INT", {"default": 3, "min": 2, "max": 3}), + "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, + "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, + "s_noise":s_noise }) + return (sampler, ) + class SamplerCustom: @classmethod def INPUT_TYPES(s): @@ -308,8 +373,11 @@ NODE_CLASS_MAPPINGS = { "SDTurboScheduler": SDTurboScheduler, "KSamplerSelect": KSamplerSelect, "SamplerEulerAncestral": SamplerEulerAncestral, + "SamplerLMS": SamplerLMS, + "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, + "SamplerDPMAdaptative": SamplerDPMAdaptative, "SplitSigmas": SplitSigmas, "FlipSigmas": FlipSigmas, } diff --git a/comfy_extras/nodes/nodes_images.py b/comfy_extras/nodes/nodes_images.py index a753feb78..eade3605d 100644 --- a/comfy_extras/nodes/nodes_images.py +++ b/comfy_extras/nodes/nodes_images.py @@ -36,7 +36,7 @@ class RepeatImageBatch: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",), - "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + "amount": ("INT", {"default": 1, "min": 1, "max": 4096}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "repeat" @@ -51,8 +51,8 @@ class ImageFromBatch: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",), - "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), - "length": ("INT", {"default": 1, "min": 1, "max": 64}), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), + "length": ("INT", {"default": 1, "min": 1, "max": 4096}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "frombatch" diff --git a/comfy_extras/nodes/nodes_perpneg.py b/comfy_extras/nodes/nodes_perpneg.py index ad07bb741..43405a2ab 100644 --- a/comfy_extras/nodes/nodes_perpneg.py +++ b/comfy_extras/nodes/nodes_perpneg.py @@ -8,7 +8,7 @@ class PerpNeg: def INPUT_TYPES(s): return {"required": {"model": ("MODEL", ), "empty_conditioning": ("CONDITIONING", ), - "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" diff --git a/comfy_extras/nodes/nodes_stable3d.py b/comfy_extras/nodes/nodes_stable3d.py index b8910581c..1c0145909 100644 --- a/comfy_extras/nodes/nodes_stable3d.py +++ b/comfy_extras/nodes/nodes_stable3d.py @@ -2,6 +2,7 @@ import torch from comfy.nodes.common import MAX_RESOLUTION from comfy import utils +import comfy.utils def camera_embeddings(elevation, azimuth): @@ -31,8 +32,8 @@ class StableZero123_Conditioning: "width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") @@ -64,10 +65,10 @@ class StableZero123_Conditioning_Batched: "width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), - "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), }} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") @@ -97,8 +98,49 @@ class StableZero123_Conditioning_Batched: latent = torch.zeros([batch_size, 4, height // 8, width // 8]) return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) +class SV3D_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 576, "min": 16, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 576, "min": 16, "max": MAX_RESOLUTION, "step": 8}), + "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + azimuth = 0 + azimuth_increment = 360 / (max(video_frames, 2) - 1) + + elevations = [] + azimuths = [] + for i in range(video_frames): + elevations.append(elevation) + azimuths.append(azimuth) + azimuth += azimuth_increment + + positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]] + latent = torch.zeros([video_frames, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + NODE_CLASS_MAPPINGS = { "StableZero123_Conditioning": StableZero123_Conditioning, "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, + "SV3D_Conditioning": SV3D_Conditioning, } diff --git a/comfy_extras/nodes/nodes_stable_cascade.py b/comfy_extras/nodes/nodes_stable_cascade.py index e0bca8bf9..c0d567863 100644 --- a/comfy_extras/nodes/nodes_stable_cascade.py +++ b/comfy_extras/nodes/nodes_stable_cascade.py @@ -74,7 +74,7 @@ class StableCascade_StageC_VAEEncode: s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1) c_latent = vae.encode(s[:,:,:,:3]) - b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4]) + b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2]) return ({ "samples": c_latent, }, { diff --git a/comfy_extras/nodes/nodes_video_model.py b/comfy_extras/nodes/nodes_video_model.py index 187495233..ce6b469ea 100644 --- a/comfy_extras/nodes/nodes_video_model.py +++ b/comfy_extras/nodes/nodes_video_model.py @@ -80,6 +80,33 @@ class VideoLinearCFGGuidance: m.set_model_sampler_cfg_function(linear_cfg) return (m, ) +class VideoTriangleCFGGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "sampling/video_models" + + def patch(self, model, min_cfg): + def linear_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + period = 1.0 + values = torch.linspace(0, 1, cond.shape[0], device=cond.device) + values = 2 * (values / period - torch.floor(values / period + 0.5)).abs() + scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1)) + + return uncond + scale * (cond - uncond) + + m = model.clone() + m.set_model_sampler_cfg_function(linear_cfg) + return (m, ) + class ImageOnlyCheckpointSave(nodes_model_merging.CheckpointSave): CATEGORY = "_for_testing" @@ -99,6 +126,7 @@ NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, + "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance, "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, }