Merge upstream

This commit is contained in:
doctorpangloss 2024-03-21 13:15:36 -07:00
parent eec0661baf
commit 005e370254
18 changed files with 394 additions and 103 deletions

View File

@ -8,7 +8,7 @@ import shutil
import threading
import time
from comfy.utils import hijack_progress
from ..utils import hijack_progress
from .extra_model_paths import load_extra_path_config
from .main_pre import args
from .. import model_management

View File

@ -202,7 +202,7 @@ class ControlNet(ControlBase):
super().cleanup()
class ControlLoraOps:
class Linear(torch.nn.Module):
class Linear(torch.nn.Module, ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@ -221,7 +221,7 @@ class ControlLoraOps:
else:
return torch.nn.functional.linear(input, weight, bias)
class Conv2d(torch.nn.Module):
class Conv2d(torch.nn.Module, ops.CastWeightBiasOp):
def __init__(
self,
in_channels,

View File

@ -748,7 +748,7 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
x = denoised
if sigmas[i + 1] > 0:
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
return x

View File

@ -95,7 +95,7 @@ class SC_Prior(LatentFormat):
class SC_B(LatentFormat):
def __init__(self):
self.scale_factor = 1.0
self.scale_factor = 1.0 / 0.43
self.latent_rgb_factors = [
[ 0.1121, 0.2006, 0.1023],
[-0.2093, -0.0222, -0.0195],

View File

@ -163,11 +163,9 @@ class ResBlock(nn.Module):
class StageA(nn.Module):
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
scale_factor=0.43): # 0.3764
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
super().__init__()
self.c_latent = c_latent
self.scale_factor = scale_factor
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
# Encoder blocks
@ -214,12 +212,11 @@ class StageA(nn.Module):
x = self.down_blocks(x)
if quantize:
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
return qe, x, indices, vq_loss + commit_loss * 0.25
else:
return x / self.scale_factor
return x
def decode(self, x):
x = x * self.scale_factor
x = self.up_blocks(x)
x = self.out_block(x)
return x

View File

@ -379,6 +379,36 @@ class SVD_img2vid(BaseModel):
out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
return out
class SV3D_u(SVD_img2vid):
def encode_adm(self, **kwargs):
augmentation = kwargs.get("augmentation_level", 0)
out = []
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
class SV3D_p(SVD_img2vid):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder_512 = Timestep(512)
def encode_adm(self, **kwargs):
augmentation = kwargs.get("augmentation_level", 0)
elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here
azimuth = kwargs.get("azimuth", 0)
noise = kwargs.get("noise", None)
out = []
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0))))
out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0))))
out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out))
return torch.cat(out, dim=1)
class Stable_Zero123(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
super().__init__(model_config, model_type, device=device)

View File

@ -276,8 +276,8 @@ def module_size(module):
class LoadedModel:
def __init__(self, model):
self.model = model
self.model_accelerated = False
self.device = model.load_device
self.weights_loaded = False
def model_memory(self):
return self.model.model_size()
@ -289,54 +289,33 @@ class LoadedModel:
return self.model_memory()
def model_load(self, lowvram_model_memory=0):
patch_model_to = None
if lowvram_model_memory == 0:
patch_model_to = self.device
patch_model_to = self.device
self.model.model_patches_to(self.device)
self.model.model_patches_to(self.model.model_dtype())
load_weights = not self.weights_loaded
try:
self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e
if lowvram_model_memory > 0:
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
mem_counter = 0
for m in self.real_model.modules():
if hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
module_mem = module_size(m)
if mem_counter + module_mem < lowvram_model_memory:
m.to(self.device)
mem_counter += module_mem
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
m.to(self.device)
mem_counter += module_size(m)
logging.warning("lowvram: loaded module regularly {}".format(m))
self.model_accelerated = True
if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
self.weights_loaded = True
return self.real_model
def model_unload(self):
if self.model_accelerated:
for m in self.real_model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
self.model_accelerated = False
self.model.unpatch_model(self.model.offload_device)
def model_unload(self, unpatch_weights=True):
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
def __eq__(self, other):
return self.model is other.model
@ -344,16 +323,35 @@ class LoadedModel:
def minimum_inference_memory():
return (1024 * 1024 * 1024)
def unload_model_clones(model):
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
with model_management_lock:
to_unload = []
for i in range(len(current_loaded_models)):
if model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
if len(to_unload) == 0:
return None
same_weights = 0
for i in to_unload:
logging.debug("unload clone {}".format(i))
current_loaded_models.pop(i).model_unload()
if model.clone_has_same_weights(current_loaded_models[i].model):
same_weights += 1
if same_weights == len(to_unload):
unload_weight = False
else:
unload_weight = True
if not force_unload:
if unload_weights_only and unload_weight == False:
return None
for i in to_unload:
logging.debug("unload clone {}{}".format(i, unload_weight))
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
return unload_weight
def free_memory(memory_required, device, keep_loaded=[]):
with model_management_lock:
@ -410,13 +408,18 @@ def load_models_gpu(models, memory_required=0):
total_memory_required = {}
for loaded_model in models_to_load:
unload_model_clones(loaded_model.model)
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded
for loaded_model in models_to_load:
model = loaded_model.model
torch_dev = model.load_device

View File

@ -2,6 +2,7 @@ import torch
import copy
import inspect
import logging
import uuid
from . import utils
from . import model_management
@ -24,6 +25,8 @@ class ModelPatcher:
self.current_device = current_device
self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
self.patches_uuid = uuid.uuid4()
def model_size(self):
if self.size > 0:
@ -38,10 +41,13 @@ class ModelPatcher:
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.patches_uuid = self.patches_uuid
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
return n
def is_clone(self, other):
@ -49,6 +55,19 @@ class ModelPatcher:
return True
return False
def clone_has_same_weights(self, clone):
if not self.is_clone(clone):
return False
if len(self.patches) == 0 and len(clone.patches) == 0:
return True
if self.patches_uuid == clone.patches_uuid:
if len(self.patches) != len(clone.patches):
logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
else:
return True
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
@ -153,6 +172,7 @@ class ModelPatcher:
current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
self.patches_uuid = uuid.uuid4()
return list(p)
def get_key_patches(self, filter_prefix=None):
@ -178,6 +198,27 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_weight_to_device(self, key, device_to=None):
if key not in self.patches:
return
weight = utils.get_attr(self.model, key)
inplace_update = self.weight_inplace_update
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None:
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
utils.copy_to_param(self.model, key, out_weight)
else:
utils.set_attr_param(self.model, key, out_weight)
def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = utils.set_attr(self.model, k, self.object_patches[k])
@ -191,23 +232,7 @@ class ModelPatcher:
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
continue
weight = model_sd[key]
inplace_update = self.weight_inplace_update
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None:
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
utils.copy_to_param(self.model, key, out_weight)
else:
utils.set_attr_param(self.model, key, out_weight)
del temp_weight
self.patch_weight_to_device(key, device_to)
if device_to is not None:
self.model.to(device_to)
@ -215,6 +240,47 @@ class ModelPatcher:
return self.model
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
self.patch_model(device_to, patch_weights=False)
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
class LowVramPatch:
def __init__(self, key, model_patcher):
self.key = key
self.model_patcher = model_patcher
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
mem_counter = 0
for n, m in self.model.named_modules():
lowvram_weight = False
if hasattr(m, "comfy_cast_weights"):
module_mem = model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if lowvram_weight:
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
if bias_key in self.patches:
m.bias_function = LowVramPatch(weight_key, self)
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
else:
if hasattr(m, "weight"):
self.patch_weight_to_device(weight_key, device_to)
self.patch_weight_to_device(bias_key, device_to)
m.to(device_to)
mem_counter += model_management.module_size(m)
logging.debug("lowvram: loaded module regularly {}".format(m))
self.model_lowvram = True
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
alpha = p[0]
@ -340,21 +406,32 @@ class ModelPatcher:
return weight
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
def unpatch_model(self, device_to=None, unpatch_weights=True):
if unpatch_weights:
if self.model_lowvram:
for m in self.model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
if self.weight_inplace_update:
for k in keys:
utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
utils.set_attr_param(self.model, k, self.backup[k])
self.model_lowvram = False
self.backup = {}
keys = list(self.backup.keys())
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
if self.weight_inplace_update:
for k in keys:
utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
utils.set_attr_param(self.model, k, self.backup[k])
self.backup.clear()
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
keys = list(self.object_patches_backup.keys())
for k in keys:

View File

@ -20,6 +20,9 @@ class EPS:
noise += latent_image
return noise
def inverse_noise_scaling(self, sigma, latent):
return latent
class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))

View File

@ -24,13 +24,20 @@ def cast_bias_weight(s, input):
non_blocking = model_management.device_supports_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if s.bias_function is not None:
bias = s.bias_function(bias)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if s.weight_function is not None:
weight = s.weight_function(weight)
return weight, bias
class CastWeightBiasOp:
comfy_cast_weights = False
weight_function = None
bias_function = None
class disable_weight_init:
class Linear(torch.nn.Linear):
comfy_cast_weights = False
class Linear(torch.nn.Linear, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -44,8 +51,7 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class Conv2d(torch.nn.Conv2d):
comfy_cast_weights = False
class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -59,8 +65,7 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class Conv3d(torch.nn.Conv3d):
comfy_cast_weights = False
class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -74,8 +79,7 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class GroupNorm(torch.nn.GroupNorm):
comfy_cast_weights = False
class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -90,8 +94,7 @@ class disable_weight_init:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm):
comfy_cast_weights = False
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -109,8 +112,7 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class ConvTranspose2d(torch.nn.ConvTranspose2d):
comfy_cast_weights = False
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
def reset_parameters(self):
return None

View File

@ -549,6 +549,7 @@ class KSAMPLER(Sampler):
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
return samples
@ -562,11 +563,11 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
sampler_function = dpm_fast_function
elif sampler_name == "dpm_adaptive":
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options)
sampler_function = dpm_adaptive_function
else:
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))

View File

@ -45,6 +45,11 @@ class SD15(supported_models_base.BASE):
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
for p in pop_keys:
if p in state_dict:
state_dict.pop(p)
replace_prefix = {"clip_l.": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@ -279,6 +284,41 @@ class SVD_img2vid(supported_models_base.BASE):
def clip_target(self):
return None
class SV3D_u(SVD_img2vid):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 256,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
vae_key_prefix = ["conditioner.embedders.1.encoder."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SV3D_u(self, device=device)
return out
class SV3D_p(SV3D_u):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 1280,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SV3D_p(self, device=device)
return out
class Stable_Zero123(supported_models_base.BASE):
unet_config = {
"context_dim": 768,
@ -400,5 +440,5 @@ class Stable_Cascade_B(Stable_Cascade_C):
return out
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
models += [SVD_img2vid]

View File

@ -183,6 +183,28 @@ class KSamplerSelect:
sampler = samplers.sampler_object(sampler_name)
return (sampler, )
class SamplerDPMPP_3M_SDE:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"noise_device": (['gpu', 'cpu'], ),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, eta, s_noise, noise_device):
if noise_device == 'cpu':
sampler_name = "dpmpp_3m_sde"
else:
sampler_name = "dpmpp_3m_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise})
return (sampler, )
class SamplerDPMPP_2M_SDE:
@classmethod
def INPUT_TYPES(s):
@ -247,6 +269,49 @@ class SamplerEulerAncestral:
sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise})
return (sampler, )
class SamplerLMS:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"order": ("INT", {"default": 4, "min": 1, "max": 100}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, order):
sampler = comfy.samplers.ksampler("lms", {"order": order})
return (sampler, )
class SamplerDPMAdaptative:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"order": ("INT", {"default": 3, "min": 2, "max": 3}),
"rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise):
sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff,
"icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta,
"s_noise":s_noise })
return (sampler, )
class SamplerCustom:
@classmethod
def INPUT_TYPES(s):
@ -308,8 +373,11 @@ NODE_CLASS_MAPPINGS = {
"SDTurboScheduler": SDTurboScheduler,
"KSamplerSelect": KSamplerSelect,
"SamplerEulerAncestral": SamplerEulerAncestral,
"SamplerLMS": SamplerLMS,
"SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"SamplerDPMAdaptative": SamplerDPMAdaptative,
"SplitSigmas": SplitSigmas,
"FlipSigmas": FlipSigmas,
}

View File

@ -36,7 +36,7 @@ class RepeatImageBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
"amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
@ -51,8 +51,8 @@ class ImageFromBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
"length": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "frombatch"

View File

@ -8,7 +8,7 @@ class PerpNeg:
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"empty_conditioning": ("CONDITIONING", ),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

View File

@ -2,6 +2,7 @@ import torch
from comfy.nodes.common import MAX_RESOLUTION
from comfy import utils
import comfy.utils
def camera_embeddings(elevation, azimuth):
@ -31,8 +32,8 @@ class StableZero123_Conditioning:
"width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@ -64,10 +65,10 @@ class StableZero123_Conditioning_Batched:
"width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
"azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@ -97,8 +98,49 @@ class StableZero123_Conditioning_Batched:
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
class SV3D_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 576, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
azimuth = 0
azimuth_increment = 360 / (max(video_frames, 2) - 1)
elevations = []
azimuths = []
for i in range(video_frames):
elevations.append(elevation)
azimuths.append(azimuth)
azimuth += azimuth_increment
positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
NODE_CLASS_MAPPINGS = {
"StableZero123_Conditioning": StableZero123_Conditioning,
"StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
"SV3D_Conditioning": SV3D_Conditioning,
}

View File

@ -74,7 +74,7 @@ class StableCascade_StageC_VAEEncode:
s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1)
c_latent = vae.encode(s[:,:,:,:3])
b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4])
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
return ({
"samples": c_latent,
}, {

View File

@ -80,6 +80,33 @@ class VideoLinearCFGGuidance:
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
class VideoTriangleCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
period = 1.0
values = torch.linspace(0, 1, cond.shape[0], device=cond.device)
values = 2 * (values / period - torch.floor(values / period + 0.5)).abs()
scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1))
return uncond + scale * (cond - uncond)
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
class ImageOnlyCheckpointSave(nodes_model_merging.CheckpointSave):
CATEGORY = "_for_testing"
@ -99,6 +126,7 @@ NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
}