Mirror of https://github.com/comfyanonymous/ComfyUI.git
Commit 005e370254 (parent eec0661baf): Merge upstream
@@ -8,7 +8,7 @@ import shutil
 import threading
 import time

-from comfy.utils import hijack_progress
+from ..utils import hijack_progress
 from .extra_model_paths import load_extra_path_config
 from .main_pre import args
 from .. import model_management
@@ -202,7 +202,7 @@ class ControlNet(ControlBase):
         super().cleanup()

 class ControlLoraOps:
-    class Linear(torch.nn.Module):
+    class Linear(torch.nn.Module, ops.CastWeightBiasOp):
         def __init__(self, in_features: int, out_features: int, bias: bool = True,
                      device=None, dtype=None) -> None:
             factory_kwargs = {'device': device, 'dtype': dtype}
@@ -221,7 +221,7 @@ class ControlLoraOps:
             else:
                 return torch.nn.functional.linear(input, weight, bias)

-    class Conv2d(torch.nn.Module):
+    class Conv2d(torch.nn.Module, ops.CastWeightBiasOp):
         def __init__(
             self,
             in_channels,
@@ -748,7 +748,7 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n

         x = denoised
         if sigmas[i + 1] > 0:
-            x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
+            x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
     return x

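Note on this hunk: re-noising for the next LCM step now goes through the model's sampling object rather than hard-coding EPS-style noise addition, which lets parameterizations that rescale their input define their own behaviour. A minimal sketch of the idea, assuming an EPS-style model_sampling object (names modeled on comfy/model_sampling.py; an illustration, not the exact class):

    import torch

    class EPSNoiseScaling:
        # EPS behaviour: re-noise a denoised latent for the next sigma step.
        def noise_scaling(self, sigma, noise, latent_image):
            return sigma * noise + latent_image

        # Identity for EPS; formats that rescale their input undo it here
        # (see the inverse_noise_scaling hunk further down).
        def inverse_noise_scaling(self, sigma, latent):
            return latent

    x = torch.zeros(1, 4, 8, 8)                  # denoised latent
    sigma_next = torch.tensor(0.5)
    x = EPSNoiseScaling().noise_scaling(sigma_next, torch.randn_like(x), x)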
@@ -95,7 +95,7 @@ class SC_Prior(LatentFormat):

 class SC_B(LatentFormat):
     def __init__(self):
-        self.scale_factor = 1.0
+        self.scale_factor = 1.0 / 0.43
         self.latent_rgb_factors = [
             [ 0.1121, 0.2006, 0.1023],
             [-0.2093, -0.0222, -0.0195],
@@ -163,11 +163,9 @@ class ResBlock(nn.Module):


 class StageA(nn.Module):
-    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
-                 scale_factor=0.43): # 0.3764
+    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
         super().__init__()
         self.c_latent = c_latent
-        self.scale_factor = scale_factor
         c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

         # Encoder blocks
@@ -214,12 +212,11 @@ class StageA(nn.Module):
         x = self.down_blocks(x)
         if quantize:
             qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
-            return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
+            return qe, x, indices, vq_loss + commit_loss * 0.25
         else:
-            return x / self.scale_factor
+            return x

     def decode(self, x):
-        x = x * self.scale_factor
         x = self.up_blocks(x)
         x = self.out_block(x)
         return x
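The two Stable Cascade hunks above are one refactor: StageA stops scaling its latents internally, and the 0.43 factor moves into the SC_B latent format (scale_factor = 1.0 / 0.43). A rough sketch of where that factor now acts, assuming the usual process_in/process_out convention of comfy/latent_formats.py:

    class LatentFormat:
        scale_factor = 1.0

        # applied to latents entering the diffusion model
        def process_in(self, latent):
            return latent * self.scale_factor

        # applied to latents handed back to the VAE decoder
        def process_out(self, latent):
            return latent / self.scale_factor

    class SC_B(LatentFormat):
        scale_factor = 1.0 / 0.43  # replaces the scaling StageA used to do itself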
@@ -379,6 +379,36 @@ class SVD_img2vid(BaseModel):
         out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
         return out

+class SV3D_u(SVD_img2vid):
+    def encode_adm(self, **kwargs):
+        augmentation = kwargs.get("augmentation_level", 0)
+
+        out = []
+        out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
+
+        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
+        return flat
+
+class SV3D_p(SVD_img2vid):
+    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.embedder_512 = Timestep(512)
+
+    def encode_adm(self, **kwargs):
+        augmentation = kwargs.get("augmentation_level", 0)
+        elevation = kwargs.get("elevation", 0) #elevation and azimuth are in degrees here
+        azimuth = kwargs.get("azimuth", 0)
+        noise = kwargs.get("noise", None)
+
+        out = []
+        out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
+        out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0))))
+        out.append(self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0))))
+
+        out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out))
+        return torch.cat(out, dim=1)
+
+
 class Stable_Zero123(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
         super().__init__(model_config, model_type, device=device)
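SV3D_u conditions only on the augmentation level, while SV3D_p also embeds per-frame elevation and azimuth (given in degrees) through the extra 512-channel Timestep embedder before concatenating everything into the ADM vector. A quick check of the angle preprocessing used above:

    import torch

    elevation = 10.0  # degrees
    # 90 - elevation converts to a polar angle; fmod wraps it into (-360, 360)
    polar = torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0)
    print(torch.deg2rad(polar))  # tensor([1.3963]) -> 80 degrees in radians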
@@ -276,8 +276,8 @@ def module_size(module):
 class LoadedModel:
     def __init__(self, model):
         self.model = model
-        self.model_accelerated = False
         self.device = model.load_device
+        self.weights_loaded = False

     def model_memory(self):
         return self.model.model_size()
@@ -289,54 +289,33 @@ class LoadedModel:
         return self.model_memory()

     def model_load(self, lowvram_model_memory=0):
-        patch_model_to = None
-        if lowvram_model_memory == 0:
-            patch_model_to = self.device
+        patch_model_to = self.device

         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())

+        load_weights = not self.weights_loaded
+
         try:
-            self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
+            if lowvram_model_memory > 0 and load_weights:
+                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
+            else:
+                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
         except Exception as e:
             self.model.unpatch_model(self.model.offload_device)
             self.model_unload()
             raise e

-        if lowvram_model_memory > 0:
-            logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
-            mem_counter = 0
-            for m in self.real_model.modules():
-                if hasattr(m, "comfy_cast_weights"):
-                    m.prev_comfy_cast_weights = m.comfy_cast_weights
-                    m.comfy_cast_weights = True
-                    module_mem = module_size(m)
-                    if mem_counter + module_mem < lowvram_model_memory:
-                        m.to(self.device)
-                        mem_counter += module_mem
-                elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
-                    m.to(self.device)
-                    mem_counter += module_size(m)
-                    logging.warning("lowvram: loaded module regularly {}".format(m))
-
-            self.model_accelerated = True
-
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)

+        self.weights_loaded = True
         return self.real_model

-    def model_unload(self):
-        if self.model_accelerated:
-            for m in self.real_model.modules():
-                if hasattr(m, "prev_comfy_cast_weights"):
-                    m.comfy_cast_weights = m.prev_comfy_cast_weights
-                    del m.prev_comfy_cast_weights
-
-            self.model_accelerated = False
-
-        self.model.unpatch_model(self.model.offload_device)
+    def model_unload(self, unpatch_weights=True):
+        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
+        self.weights_loaded = self.weights_loaded and not unpatch_weights

     def __eq__(self, other):
         return self.model is other.model
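The rewritten model_load defers the lowvram work to ModelPatcher.patch_model_lowvram (added in the model_patcher.py hunks below) and introduces a weights_loaded flag so a model whose weights survived a weights-preserving unload is not re-patched. Roughly, the lifecycle looks like this (a sketch using the methods above):

    lm = LoadedModel(patcher)               # weights_loaded == False
    lm.model_load()                         # patches weights, weights_loaded -> True
    lm.model_unload(unpatch_weights=False)  # offload only; weights_loaded stays True
    lm.model_load()                         # load_weights == False: weight patching is skipped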
@@ -344,16 +323,35 @@ class LoadedModel:
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)

-def unload_model_clones(model):
+def unload_model_clones(model, unload_weights_only=True, force_unload=True):
     with model_management_lock:
         to_unload = []
         for i in range(len(current_loaded_models)):
             if model.is_clone(current_loaded_models[i].model):
                 to_unload = [i] + to_unload

+        if len(to_unload) == 0:
+            return None
+
+        same_weights = 0
         for i in to_unload:
-            logging.debug("unload clone {}".format(i))
-            current_loaded_models.pop(i).model_unload()
+            if model.clone_has_same_weights(current_loaded_models[i].model):
+                same_weights += 1
+
+        if same_weights == len(to_unload):
+            unload_weight = False
+        else:
+            unload_weight = True
+
+        if not force_unload:
+            if unload_weights_only and unload_weight == False:
+                return None
+
+        for i in to_unload:
+            logging.debug("unload clone {} {}".format(i, unload_weight))
+            current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
+
+        return unload_weight

 def free_memory(memory_required, device, keep_loaded=[]):
     with model_management_lock:
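unload_model_clones now reports what it did instead of always unpatching, using clone_has_same_weights (added below) to decide whether reverting the patched weights is actually necessary. Its return contract, as implemented above:

    # None  -> no clones found, or nothing unloaded (force_unload=False short-circuit)
    # False -> clones unloaded, weights identical, so weight patches were kept
    # True  -> clones unloaded and their weights were unpatched
    weights_unloaded = unload_model_clones(patcher, unload_weights_only=False, force_unload=False)
    if weights_unloaded is not None:
        loaded_model.weights_loaded = not weights_unloaded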
@@ -410,13 +408,18 @@ def load_models_gpu(models, memory_required=0):

     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model)
+        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

     for device in total_memory_required:
         if device != torch.device("cpu"):
             free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)

+    for loaded_model in models_to_load:
+        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
+        if weights_unloaded is not None:
+            loaded_model.weights_loaded = not weights_unloaded
+
     for loaded_model in models_to_load:
         model = loaded_model.model
         torch_dev = model.load_device
@@ -2,6 +2,7 @@ import torch
 import copy
 import inspect
 import logging
+import uuid

 from . import utils
 from . import model_management
@@ -24,6 +25,8 @@ class ModelPatcher:
         self.current_device = current_device

         self.weight_inplace_update = weight_inplace_update
+        self.model_lowvram = False
+        self.patches_uuid = uuid.uuid4()

     def model_size(self):
         if self.size > 0:
@@ -38,10 +41,13 @@ class ModelPatcher:
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
+        n.patches_uuid = self.patches_uuid

         n.object_patches = self.object_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys
+        n.backup = self.backup
+        n.object_patches_backup = self.object_patches_backup
         return n

     def is_clone(self, other):
@@ -49,6 +55,19 @@ class ModelPatcher:
                 return True
         return False

+    def clone_has_same_weights(self, clone):
+        if not self.is_clone(clone):
+            return False
+
+        if len(self.patches) == 0 and len(clone.patches) == 0:
+            return True
+
+        if self.patches_uuid == clone.patches_uuid:
+            if len(self.patches) != len(clone.patches):
+                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
+            else:
+                return True
+
     def memory_required(self, input_shape):
         return self.model.memory_required(input_shape=input_shape)

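Note that clone_has_same_weights only returns True or False explicitly; the uuid-mismatch paths fall through and return None, which callers such as unload_model_clones treat as falsy, i.e. "weights differ". The decision table, as written:

    # not a clone of self                     -> False
    # both patch sets empty                   -> True
    # same patches_uuid and same patch count  -> True
    # same patches_uuid, different count      -> logs a warning, falls through -> None
    # different patches_uuid                  -> falls through -> None (falsy: assume different)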
@@ -153,6 +172,7 @@ class ModelPatcher:
                 current_patches.append((strength_patch, patches[k], strength_model))
                 self.patches[k] = current_patches

+        self.patches_uuid = uuid.uuid4()
         return list(p)

     def get_key_patches(self, filter_prefix=None):
@@ -178,6 +198,27 @@ class ModelPatcher:
                 sd.pop(k)
         return sd

+    def patch_weight_to_device(self, key, device_to=None):
+        if key not in self.patches:
+            return
+
+        weight = utils.get_attr(self.model, key)
+
+        inplace_update = self.weight_inplace_update
+
+        if key not in self.backup:
+            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+
+        if device_to is not None:
+            temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+        else:
+            temp_weight = weight.to(torch.float32, copy=True)
+        out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+        if inplace_update:
+            utils.copy_to_param(self.model, key, out_weight)
+        else:
+            utils.set_attr_param(self.model, key, out_weight)
+
     def patch_model(self, device_to=None, patch_weights=True):
         for k in self.object_patches:
             old = utils.set_attr(self.model, k, self.object_patches[k])
@@ -191,23 +232,7 @@ class ModelPatcher:
                     logging.warning("could not patch. key doesn't exist in model: {}".format(key))
                     continue

-                weight = model_sd[key]
-
-                inplace_update = self.weight_inplace_update
-
-                if key not in self.backup:
-                    self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
-
-                if device_to is not None:
-                    temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
-                else:
-                    temp_weight = weight.to(torch.float32, copy=True)
-                out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
-                if inplace_update:
-                    utils.copy_to_param(self.model, key, out_weight)
-                else:
-                    utils.set_attr_param(self.model, key, out_weight)
-                del temp_weight
+                self.patch_weight_to_device(key, device_to)

         if device_to is not None:
             self.model.to(device_to)
@@ -215,6 +240,47 @@ class ModelPatcher:

         return self.model

+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
+        self.patch_model(device_to, patch_weights=False)
+
+        logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
+        class LowVramPatch:
+            def __init__(self, key, model_patcher):
+                self.key = key
+                self.model_patcher = model_patcher
+            def __call__(self, weight):
+                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+
+        mem_counter = 0
+        for n, m in self.model.named_modules():
+            lowvram_weight = False
+            if hasattr(m, "comfy_cast_weights"):
+                module_mem = model_management.module_size(m)
+                if mem_counter + module_mem >= lowvram_model_memory:
+                    lowvram_weight = True
+
+            weight_key = "{}.weight".format(n)
+            bias_key = "{}.bias".format(n)
+
+            if lowvram_weight:
+                if weight_key in self.patches:
+                    m.weight_function = LowVramPatch(weight_key, self)
+                if bias_key in self.patches:
+                    m.bias_function = LowVramPatch(weight_key, self)
+
+                m.prev_comfy_cast_weights = m.comfy_cast_weights
+                m.comfy_cast_weights = True
+            else:
+                if hasattr(m, "weight"):
+                    self.patch_weight_to_device(weight_key, device_to)
+                    self.patch_weight_to_device(bias_key, device_to)
+                    m.to(device_to)
+                    mem_counter += model_management.module_size(m)
+                    logging.debug("lowvram: loaded module regularly {}".format(m))
+
+        self.model_lowvram = True
+        return self.model
+
     def calculate_weight(self, patches, weight, key):
         for p in patches:
             alpha = p[0]
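Modules over the memory budget keep their unpatched weights on the offload side and instead get a LowVramPatch as their weight_function/bias_function, so patches (LoRA and friends) are applied on the fly each time cast_bias_weight casts the tensors (see the comfy/ops.py hunk below). One flag for reviewers: the bias branch builds LowVramPatch(weight_key, self); from the surrounding code bias_key would be expected there, so this looks like a probable typo carried in the commit. A sketch of how the hook fires at cast time (module and x are hypothetical names):

    weight = module.weight.to(device=x.device, dtype=x.dtype)
    if module.weight_function is not None:
        weight = module.weight_function(weight)  # LowVramPatch recomputes the patched weight here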
@@ -340,21 +406,32 @@ class ModelPatcher:

         return weight

-    def unpatch_model(self, device_to=None):
-        keys = list(self.backup.keys())
+    def unpatch_model(self, device_to=None, unpatch_weights=True):
+        if unpatch_weights:
+            if self.model_lowvram:
+                for m in self.model.modules():
+                    if hasattr(m, "prev_comfy_cast_weights"):
+                        m.comfy_cast_weights = m.prev_comfy_cast_weights
+                        del m.prev_comfy_cast_weights
+                    m.weight_function = None
+                    m.bias_function = None

-        if self.weight_inplace_update:
-            for k in keys:
-                utils.copy_to_param(self.model, k, self.backup[k])
-        else:
-            for k in keys:
-                utils.set_attr_param(self.model, k, self.backup[k])
+            self.model_lowvram = False

-        self.backup = {}
+            keys = list(self.backup.keys())

-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
+            if self.weight_inplace_update:
+                for k in keys:
+                    utils.copy_to_param(self.model, k, self.backup[k])
+            else:
+                for k in keys:
+                    utils.set_attr_param(self.model, k, self.backup[k])
+
+            self.backup.clear()
+
+            if device_to is not None:
+                self.model.to(device_to)
+                self.current_device = device_to

         keys = list(self.object_patches_backup.keys())
         for k in keys:
@@ -20,6 +20,9 @@ class EPS:
         noise += latent_image
         return noise

+    def inverse_noise_scaling(self, sigma, latent):
+        return latent
+
 class V_PREDICTION(EPS):
     def calculate_denoised(self, sigma, model_output, model_input):
         sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
comfy/ops.py
@@ -24,13 +24,20 @@ def cast_bias_weight(s, input):
     non_blocking = model_management.device_supports_non_blocking(input.device)
     if s.bias is not None:
         bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+        if s.bias_function is not None:
+            bias = s.bias_function(bias)
     weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    if s.weight_function is not None:
+        weight = s.weight_function(weight)
     return weight, bias

+class CastWeightBiasOp:
+    comfy_cast_weights = False
+    weight_function = None
+    bias_function = None
+
 class disable_weight_init:
-    class Linear(torch.nn.Linear):
-        comfy_cast_weights = False
+    class Linear(torch.nn.Linear, CastWeightBiasOp):
         def reset_parameters(self):
             return None

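CastWeightBiasOp collects the per-op attributes that were previously declared one class at a time (comfy_cast_weights plus the two new hooks), so every op below and the ControlLoraOps classes in the earlier hunk inherit them from one place. A minimal sketch of a custom op written against the same pattern, assuming cast_bias_weight from this file:

    class MyLinear(torch.nn.Linear, CastWeightBiasOp):
        def reset_parameters(self):
            return None  # disable_weight_init semantics: skip init work

        def forward(self, input):
            if self.comfy_cast_weights:
                weight, bias = cast_bias_weight(self, input)
                return torch.nn.functional.linear(input, weight, bias)
            return super().forward(input)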
@@ -44,8 +51,7 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)

-    class Conv2d(torch.nn.Conv2d):
-        comfy_cast_weights = False
+    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
         def reset_parameters(self):
             return None

@@ -59,8 +65,7 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)

-    class Conv3d(torch.nn.Conv3d):
-        comfy_cast_weights = False
+    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
         def reset_parameters(self):
             return None

@@ -74,8 +79,7 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)

-    class GroupNorm(torch.nn.GroupNorm):
-        comfy_cast_weights = False
+    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
         def reset_parameters(self):
             return None

@@ -90,8 +94,7 @@ class disable_weight_init:
                 return super().forward(*args, **kwargs)


-    class LayerNorm(torch.nn.LayerNorm):
-        comfy_cast_weights = False
+    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
         def reset_parameters(self):
             return None

@@ -109,8 +112,7 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)

-    class ConvTranspose2d(torch.nn.ConvTranspose2d):
-        comfy_cast_weights = False
+    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
         def reset_parameters(self):
             return None

@@ -549,6 +549,7 @@ class KSAMPLER(Sampler):
         k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)

         samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
+        samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
         return samples


@@ -562,11 +563,11 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
             return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
         sampler_function = dpm_fast_function
     elif sampler_name == "dpm_adaptive":
-        def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
+        def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
             sigma_min = sigmas[-1]
             if sigma_min == 0:
                 sigma_min = sigmas[-2]
-            return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
+            return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options)
         sampler_function = dpm_adaptive_function
     else:
         sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
@@ -45,6 +45,11 @@ class SD15(supported_models_base.BASE):
         return state_dict

     def process_clip_state_dict_for_saving(self, state_dict):
+        pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
+        for p in pop_keys:
+            if p in state_dict:
+                state_dict.pop(p)
+
         replace_prefix = {"clip_l.": "cond_stage_model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)

@@ -279,6 +284,41 @@ class SVD_img2vid(supported_models_base.BASE):
     def clip_target(self):
         return None

+class SV3D_u(SVD_img2vid):
+    unet_config = {
+        "model_channels": 320,
+        "in_channels": 8,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
+        "context_dim": 1024,
+        "adm_in_channels": 256,
+        "use_temporal_attention": True,
+        "use_temporal_resblock": True
+    }
+
+    vae_key_prefix = ["conditioner.embedders.1.encoder."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.SV3D_u(self, device=device)
+        return out
+
+class SV3D_p(SV3D_u):
+    unet_config = {
+        "model_channels": 320,
+        "in_channels": 8,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
+        "context_dim": 1024,
+        "adm_in_channels": 1280,
+        "use_temporal_attention": True,
+        "use_temporal_resblock": True
+    }
+
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.SV3D_p(self, device=device)
+        return out
+
 class Stable_Zero123(supported_models_base.BASE):
     unet_config = {
         "context_dim": 768,
@@ -400,5 +440,5 @@ class Stable_Cascade_B(Stable_Cascade_C):
         return out


-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
+models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
 models += [SVD_img2vid]
@@ -183,6 +183,28 @@ class KSamplerSelect:
         sampler = samplers.sampler_object(sampler_name)
         return (sampler, )

+class SamplerDPMPP_3M_SDE:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                     "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                     "noise_device": (['gpu', 'cpu'], ),
+                    }
+               }
+    RETURN_TYPES = ("SAMPLER",)
+    CATEGORY = "sampling/custom_sampling/samplers"
+
+    FUNCTION = "get_sampler"
+
+    def get_sampler(self, eta, s_noise, noise_device):
+        if noise_device == 'cpu':
+            sampler_name = "dpmpp_3m_sde"
+        else:
+            sampler_name = "dpmpp_3m_sde_gpu"
+        sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise})
+        return (sampler, )
+
 class SamplerDPMPP_2M_SDE:
     @classmethod
     def INPUT_TYPES(s):
@@ -247,6 +269,49 @@ class SamplerEulerAncestral:
         sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise})
         return (sampler, )

+class SamplerLMS:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"order": ("INT", {"default": 4, "min": 1, "max": 100}),
+                    }
+               }
+    RETURN_TYPES = ("SAMPLER",)
+    CATEGORY = "sampling/custom_sampling/samplers"
+
+    FUNCTION = "get_sampler"
+
+    def get_sampler(self, order):
+        sampler = comfy.samplers.ksampler("lms", {"order": order})
+        return (sampler, )
+
+class SamplerDPMAdaptative:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"order": ("INT", {"default": 3, "min": 2, "max": 3}),
+                     "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                     "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                     "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                     "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                     "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                     "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                     "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                     "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                     "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                    }
+               }
+    RETURN_TYPES = ("SAMPLER",)
+    CATEGORY = "sampling/custom_sampling/samplers"
+
+    FUNCTION = "get_sampler"
+
+    def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise):
+        sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff,
+                                                           "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta,
+                                                           "s_noise":s_noise })
+        return (sampler, )
+
 class SamplerCustom:
     @classmethod
     def INPUT_TYPES(s):
@@ -308,8 +373,11 @@ NODE_CLASS_MAPPINGS = {
     "SDTurboScheduler": SDTurboScheduler,
     "KSamplerSelect": KSamplerSelect,
     "SamplerEulerAncestral": SamplerEulerAncestral,
+    "SamplerLMS": SamplerLMS,
+    "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE,
     "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
     "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
+    "SamplerDPMAdaptative": SamplerDPMAdaptative,
     "SplitSigmas": SplitSigmas,
     "FlipSigmas": FlipSigmas,
 }
@@ -36,7 +36,7 @@ class RepeatImageBatch:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "image": ("IMAGE",),
-                              "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
+                              "amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
                              }}
     RETURN_TYPES = ("IMAGE",)
     FUNCTION = "repeat"
@@ -51,8 +51,8 @@ class ImageFromBatch:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "image": ("IMAGE",),
-                              "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
-                              "length": ("INT", {"default": 1, "min": 1, "max": 64}),
+                              "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
+                              "length": ("INT", {"default": 1, "min": 1, "max": 4096}),
                              }}
     RETURN_TYPES = ("IMAGE",)
     FUNCTION = "frombatch"
@@ -8,7 +8,7 @@ class PerpNeg:
     def INPUT_TYPES(s):
         return {"required": {"model": ("MODEL", ),
                              "empty_conditioning": ("CONDITIONING", ),
-                             "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
+                             "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                             }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
@@ -2,6 +2,7 @@ import torch

 from comfy.nodes.common import MAX_RESOLUTION
 from comfy import utils
+import comfy.utils


 def camera_embeddings(elevation, azimuth):
@@ -31,8 +32,8 @@ class StableZero123_Conditioning:
                              "width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
                              "height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                             "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
-                             "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
+                             "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
+                             "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
                             }}
     RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
     RETURN_NAMES = ("positive", "negative", "latent")
@@ -64,10 +65,10 @@ class StableZero123_Conditioning_Batched:
                              "width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
                              "height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                             "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
-                             "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
-                             "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
-                             "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
+                             "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
+                             "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
+                             "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
+                             "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}),
                             }}
     RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
     RETURN_NAMES = ("positive", "negative", "latent")
@@ -97,8 +98,49 @@ class StableZero123_Conditioning_Batched:
         latent = torch.zeros([batch_size, 4, height // 8, width // 8])
         return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})

+class SV3D_Conditioning:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_vision": ("CLIP_VISION",),
+                              "init_image": ("IMAGE",),
+                              "vae": ("VAE",),
+                              "width": ("INT", {"default": 576, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+                              "height": ("INT", {"default": 576, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+                              "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}),
+                              "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}),
+                             }}
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/3d_models"
+
+    def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation):
+        output = clip_vision.encode_image(init_image)
+        pooled = output.image_embeds.unsqueeze(0)
+        pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
+        encode_pixels = pixels[:,:,:,:3]
+        t = vae.encode(encode_pixels)
+
+        azimuth = 0
+        azimuth_increment = 360 / (max(video_frames, 2) - 1)
+
+        elevations = []
+        azimuths = []
+        for i in range(video_frames):
+            elevations.append(elevation)
+            azimuths.append(azimuth)
+            azimuth += azimuth_increment
+
+        positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]]
+        negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]]
+        latent = torch.zeros([video_frames, 4, height // 8, width // 8])
+        return (positive, negative, {"samples":latent})
+
+
 NODE_CLASS_MAPPINGS = {
     "StableZero123_Conditioning": StableZero123_Conditioning,
     "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
+    "SV3D_Conditioning": SV3D_Conditioning,
 }

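Worked example for the azimuth schedule above: with the default 21 frames the increment is 360 / (21 - 1) = 18 degrees, so the sweep runs 0, 18, ..., 360, i.e. the final frame lands back on the starting view; the max(video_frames, 2) guard only prevents a division by zero when a single frame is requested.

    video_frames = 21
    azimuth_increment = 360 / (max(video_frames, 2) - 1)        # 18.0
    azimuths = [i * azimuth_increment for i in range(video_frames)]
    print(azimuths[0], azimuths[1], azimuths[-1])               # 0.0 18.0 360.0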
@@ -74,7 +74,7 @@ class StableCascade_StageC_VAEEncode:
         s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1)

         c_latent = vae.encode(s[:,:,:,:3])
-        b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4])
+        b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
         return ({
             "samples": c_latent,
         }, {
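(height // 8) * 2 snaps to the stride of 8 before doubling, which differs from height // 4 whenever the image size is not a multiple of 8, presumably matching the latent shape the Stage B pipeline actually produces. A quick check:

    for height in (1024, 1020):
        print(height // 4, (height // 8) * 2)
    # 1024 -> 256 256
    # 1020 -> 255 254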
@@ -80,6 +80,33 @@ class VideoLinearCFGGuidance:
         m.set_model_sampler_cfg_function(linear_cfg)
         return (m, )

+class VideoTriangleCFGGuidance:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
+                             }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "sampling/video_models"
+
+    def patch(self, model, min_cfg):
+        def linear_cfg(args):
+            cond = args["cond"]
+            uncond = args["uncond"]
+            cond_scale = args["cond_scale"]
+            period = 1.0
+            values = torch.linspace(0, 1, cond.shape[0], device=cond.device)
+            values = 2 * (values / period - torch.floor(values / period + 0.5)).abs()
+            scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1))
+
+            return uncond + scale * (cond - uncond)
+
+        m = model.clone()
+        m.set_model_sampler_cfg_function(linear_cfg)
+        return (m, )
+
 class ImageOnlyCheckpointSave(nodes_model_merging.CheckpointSave):
     CATEGORY = "_for_testing"

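The inner linear_cfg here actually evaluates a triangle wave over the frame index: values runs 0..1 across the batch, and 2 * |v - round(v)| folds that into 0 -> 1 -> 0, so guidance starts at min_cfg, peaks at cond_scale mid-sequence, and falls back to min_cfg (the linear variant above ramps monotonically instead). A quick check with 5 frames:

    import torch

    min_cfg, cond_scale, frames = 1.0, 7.0, 5
    values = torch.linspace(0, 1, frames)
    values = 2 * (values - torch.floor(values + 0.5)).abs()   # period = 1.0
    print(values)                                    # tensor([0.0000, 0.5000, 1.0000, 0.5000, 0.0000])
    print(values * (cond_scale - min_cfg) + min_cfg) # tensor([1., 4., 7., 4., 1.])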
@@ -99,6 +126,7 @@ NODE_CLASS_MAPPINGS = {
     "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
     "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
     "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
+    "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
     "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
 }
