Merge remote-tracking branch 'upstream/master' into node_expansion

Jacob Segal 2023-07-29 00:08:44 -07:00
commit 95c8e22fae
23 changed files with 575 additions and 260 deletions


@@ -93,8 +93,8 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
This is the command to install the nightly with ROCm 5.5 that supports the 7000 series and might have some performance improvements:
This is the command to install the nightly with ROCm 5.6 that supports the 7000 series and might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.5 -r requirements.txt```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6 -r requirements.txt```
### NVIDIA


@@ -42,7 +42,7 @@ parser.add_argument("--auto-launch", action="store_true", help="Automatically la
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Enable cudaMallocAsync.")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
@@ -84,6 +84,8 @@ parser.add_argument("--dont-print-server", action="store_true", help="Don't prin
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
args = parser.parse_args()
if args.windows_standalone_build:


@@ -148,6 +148,10 @@ vae_conversion_map_attn = [
("q.", "query."),
("k.", "key."),
("v.", "value."),
("q.", "to_q."),
("k.", "to_k."),
("v.", "to_v."),
("proj_out.", "to_out.0."),
("proj_out.", "proj_attn."), ("proj_out.", "proj_attn."),
] ]


@@ -91,7 +91,9 @@ class DiscreteSchedule(nn.Module):
return log_sigma.exp()
def predict_eps_discrete_timestep(self, input, t, **kwargs):
sigma = self.t_to_sigma(t.round())
if t.dtype != torch.int64 and t.dtype != torch.int32:
t = t.round()
sigma = self.t_to_sigma(t)
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)


@@ -3,7 +3,6 @@ import math
from scipy import integrate
import torch
from torch import nn
from torchdiffeq import odeint
import torchsde
from tqdm.auto import trange, tqdm
@@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
return x
@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
v = torch.randint_like(x, 2) * 2 - 1
fevals = 0
def ode_fn(sigma, x):
nonlocal fevals
with torch.enable_grad():
x = x[0].detach().requires_grad_()
denoised = model(x, sigma * s_in, **extra_args)
d = to_d(x, sigma, denoised)
fevals += 1
grad = torch.autograd.grad((d * v).sum(), x)[0]
d_ll = (v * grad).flatten(1).sum(1)
return d.detach(), d_ll
x_min = x, x.new_zeros([x.shape[0]])
t = x.new_tensor([sigma_min, sigma_max])
sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
latent, delta_ll = sol[0][-1], sol[1][-1]
ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
return ll_prior + delta_ll, {'fevals': fevals}
class PIDStepSizeController:
"""A PID controller for ODE adaptive step size control."""
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):


@@ -164,7 +164,6 @@ class SDXLRefiner(BaseModel):
else:
aesthetic_score = kwargs.get("aesthetic_score", 6)
print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
out = []
out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([width])))
@@ -188,7 +187,6 @@ class SDXL(BaseModel):
target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height)
print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
out = []
out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([width])))


@@ -118,3 +118,57 @@ def model_config_from_unet_config(unet_config):
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
return model_config_from_unet_config(unet_config)
def model_config_from_diffusers_unet(state_dict, use_fp16):
match = {}
match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
match["adm_in_channels"] = None
if "class_embedding.linear_1.weight" in state_dict:
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
elif "add_embedding.linear_1.weight" in state_dict:
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
for unet_config in supported_models:
matches = True
for k in match:
if match[k] != unet_config[k]:
matches = False
break
if matches:
return model_config_from_unet_config(unet_config)
return None
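For orientation, a minimal sketch (not part of the commit) of how the new detection helper might be called on a diffusers-format UNet; the file path and safetensors loading are assumptions, while the `.unet_config` access mirrors its use in `load_controlnet` below:
```python
# Hedged sketch: detect a diffusers-layout UNet with the helper added above.
import safetensors.torch
from comfy import model_detection

sd = safetensors.torch.load_file("unet/diffusion_pytorch_model.safetensors")  # placeholder path
config = model_detection.model_config_from_diffusers_unet(sd, use_fp16=True)
if config is None:
    print("unsupported UNet layout")
else:
    print(config.unet_config["context_dim"])  # e.g. 2048 for SDXL, 768 for SD1.5
```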


@@ -49,6 +49,7 @@ except:
try:
if torch.backends.mps.is_available():
cpu_state = CPUState.MPS
import torch.mps
except:
pass
@@ -280,19 +281,23 @@ def load_model_gpu(model):
vram_set_state = VRAMState.LOW_VRAM
real_model = model.model
patch_model_to = None
if vram_set_state == VRAMState.DISABLED:
pass
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
model_accelerated = False
real_model.to(torch_dev)
patch_model_to = torch_dev
try:
real_model = model.patch_model()
real_model = model.patch_model(device_to=patch_model_to)
except Exception as e:
model.unpatch_model()
unload_model()
raise e
if patch_model_to is not None:
real_model.to(torch_dev)
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
@@ -529,7 +534,7 @@ def should_use_fp16(device=None, model_params=0):
return False
#FP16 is just broken on these cards
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600"]
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450"]
for x in nvidia_16_series:
if x in props.name:
return False


@@ -17,6 +17,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in cond[1]:
timestep_start = cond[1]['timestep_start']
if timestep_in[0] > timestep_start:
return None
if 'timestep_end' in cond[1]:
timestep_end = cond[1]['timestep_end']
if timestep_in[0] < timestep_end:
return None
if 'area' in cond[1]:
area = cond[1]['area']
if 'strength' in cond[1]:
@@ -248,7 +256,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
c['transformer_options'] = transformer_options
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
model_management.throw_exception_if_processing_interrupted()
@@ -425,6 +436,35 @@ def create_cond_with_same_area_if_none(conds, c):
n = c[1].copy()
conds += [[smallest[0], n]]
def calculate_start_end_timesteps(model, conds):
for t in range(len(conds)):
x = conds[t]
timestep_start = None
timestep_end = None
if 'start_percent' in x[1]:
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
if 'end_percent' in x[1]:
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
if (timestep_start is not None) or (timestep_end is not None):
n = x[1].copy()
if (timestep_start is not None):
n['timestep_start'] = timestep_start
if (timestep_end is not None):
n['timestep_end'] = timestep_end
conds[t] = [x[0], n]
def pre_run_control(model, conds):
for t in range(len(conds)):
x = conds[t]
timestep_start = None
timestep_end = None
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
if 'control' in x[1]:
x[1]['control'].pre_run(model.inner_model, percent_to_timestep_function)
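A small sketch of the percent-to-timestep mapping these two helpers rely on, assuming a model wrapper that exposes `t_to_sigma`/`sigma_to_t` like the wrapped k-diffusion schedule:
```python
# Sketch of the percent -> timestep conversion used by calculate_start_end_timesteps
# and pre_run_control above.
import torch

def percent_to_timestep(model, percent):
    # percent 1.0 ~ timestep 999 (start of sampling), percent 0.0 ~ timestep 0 (end)
    return model.sigma_to_t(model.t_to_sigma(torch.tensor(percent * 999.0)))
```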
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
cond_other = []
@@ -568,13 +608,18 @@ class KSampler:
resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
calculate_start_end_timesteps(self.model_wrap, negative)
calculate_start_end_timesteps(self.model_wrap, positive)
#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
for c in negative:
create_cond_with_same_area_if_none(positive, c)
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
pre_run_control(self.model_wrap, negative + positive)
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.is_adm():


@@ -170,6 +170,8 @@ def model_lora_keys_clip(model, key_map={}):
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk: if k in sdk:
@ -202,6 +204,14 @@ def model_lora_keys_unet(model, key_map={}):
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k]) key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
return key_map return key_map
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value))
del prev
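An illustration of the `set_attr` helper on a made-up toy module (the module and tensors are placeholders for the example):
```python
# Illustration only: set_attr (defined above) walks a dotted path like "layer.weight"
# and swaps the leaf tensor for a new nn.Parameter.
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)

m = Toy()
set_attr(m, "layer.weight", torch.zeros(4, 4))
print(m.layer.weight.sum())  # tensor(0., grad_fn=<SumBackward0>)
```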
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0):
self.size = size
@@ -330,7 +340,7 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_model(self):
def patch_model(self, device_to=None):
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
@@ -340,10 +350,14 @@
weight = model_sd[key]
if key not in self.backup:
self.backup[key] = weight.to(self.offload_device, copy=True)
temp_weight = weight.to(torch.float32, copy=True)
weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
self.backup[key] = weight.to(self.offload_device)
if device_to is not None:
temp_weight = weight.float().to(device_to, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
set_attr(self.model, key, out_weight)
del temp_weight
return self.model
@@ -376,7 +390,10 @@
mat3 = v[3].float().to(weight.device)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
@@ -407,7 +424,10 @@
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
else: #loha
w1a = v[0]
w1b = v[1]
@@ -424,18 +444,15 @@
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
return weight
def unpatch_model(self):
keys = list(self.backup.keys())
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value))
del prev
for k in keys:
set_attr(self.model, k, self.backup[k])
@@ -658,16 +675,57 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else:
return torch.cat([tensor] * batched_number, dim=0)
class ControlNet:
def __init__(self, control_model, global_average_pooling=False, device=None):
class ControlBase:
def __init__(self, device=None):
self.control_model = control_model
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
self.timestep_percent_range = (1.0, 0.0)
self.timestep_range = None
if device is None:
device = model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
return self
def pre_run(self, model, percent_to_timestep_function):
self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
if self.previous_controlnet is not None:
self.previous_controlnet.pre_run(model, percent_to_timestep_function)
def set_previous_controlnet(self, controlnet):
self.previous_controlnet = controlnet
return self
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.timestep_range = None
def get_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
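A hedged sketch of how a node can schedule a controlnet over part of sampling with the new `timestep_percent_range` argument; `control_net` and `hint_image` are placeholders:
```python
# Sketch: restrict a controlnet to roughly the first 60% of sampling.
# 1.0 means the very first sampling step, 0.0 the last.
c_net = control_net.copy().set_cond_hint(hint_image, strength=0.8,
                                         timestep_percent_range=(1.0, 0.4))
# KSampler later calls pre_run(), which turns the percents into real timesteps:
# c_net.pre_run(inner_model, percent_to_timestep_function)
```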
class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None):
super().__init__(device)
self.control_model = control_model
self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond, batched_number):
@@ -675,6 +733,13 @@ class ControlNet:
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
if control_prev is not None:
return control_prev
else:
return {}
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
@@ -722,37 +787,64 @@ class ControlNet:
out['input'] = control_prev['input']
return out
def set_cond_hint(self, cond_hint, strength=1.0):
self.cond_hint_original = cond_hint
self.strength = strength
return self
def set_previous_controlnet(self, controlnet):
self.previous_controlnet = controlnet
return self
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
def copy(self):
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
self.copy_to(c)
return c
def get_models(self):
out = []
out = super().get_models()
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
out.append(self.control_model)
return out
def load_controlnet(ckpt_path, model=None):
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
controlnet_config = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
use_fp16 = model_management.should_use_fp16()
controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config
diffusers_keys = utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
count = 0
loop = True
while loop:
suffix = [".weight", ".bias"]
for s in suffix:
k_in = "controlnet_down_blocks.{}{}".format(count, s)
k_out = "zero_convs.{}.0{}".format(count, s)
if k_in not in controlnet_data:
loop = False
break
diffusers_keys[k_in] = k_out
count += 1
count = 0
loop = True
while loop:
suffix = [".weight", ".bias"]
for s in suffix:
if count == 0:
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
else:
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
k_out = "input_hint_block.{}{}".format(count * 2, s)
if k_in not in controlnet_data:
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
loop = False
diffusers_keys[k_in] = k_out
count += 1
new_sd = {}
for k in diffusers_keys:
if k in controlnet_data:
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
controlnet_data = new_sd
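A minimal sketch of the down-block remapping the first loop above produces, assuming 12 numbered zero-conv blocks as in SD1.5-style controlnets (the count is an assumption for illustration):
```python
# Sketch of the controlnet_down_blocks -> zero_convs remap built by the loop above.
remap = {}
for count in range(12):
    for s in (".weight", ".bias"):
        remap["controlnet_down_blocks.{}{}".format(count, s)] = "zero_convs.{}.0{}".format(count, s)
print(remap["controlnet_down_blocks.0.weight"])  # zero_convs.0.0.weight
```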
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
key = 'zero_convs.0.0.weight'
@@ -768,9 +860,9 @@ def load_controlnet(ckpt_path, model=None):
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
return net
use_fp16 = model_management.should_use_fp16()
if controlnet_config is None:
use_fp16 = model_management.should_use_fp16()
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = 3
control_model = cldm.ControlNet(**controlnet_config)
@@ -810,24 +902,25 @@
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
return control
class T2IAdapter:
class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, device=None):
super().__init__(device)
self.t2i_model = t2i_model
self.channels_in = channels_in
self.strength = 1.0
if device is None:
device = model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.control_input = None
self.cond_hint_original = None
self.cond_hint = None
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
if control_prev is not None:
return control_prev
else:
return {}
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
@@ -872,33 +965,11 @@ class T2IAdapter:
out['output'] = control_prev['output']
return out
def set_cond_hint(self, cond_hint, strength=1.0):
self.cond_hint_original = cond_hint
self.strength = strength
return self
def set_previous_controlnet(self, controlnet):
self.previous_controlnet = controlnet
return self
def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in)
c.cond_hint_original = self.cond_hint_original
self.copy_to(c)
c.strength = self.strength
return c
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
def get_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def load_t2i_adapter(t2i_data):
keys = t2i_data.keys()
@@ -1128,66 +1199,24 @@ def load_unet(unet_path): #load unet in diffusers format
parameters = calculate_parameters(sd, "")
fp16 = model_management.should_use_fp16(model_params=parameters)
match = {}
match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
match["model_channels"] = sd["conv_in.weight"].shape[0]
match["in_channels"] = sd["conv_in.weight"].shape[1]
match["adm_in_channels"] = None
if "class_embedding.linear_1.weight" in sd:
match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
print("match", match)
for unet_config in supported_models:
matches = True
for k in match:
if match[k] != unet_config[k]:
matches = False
break
if matches:
diffusers_keys = utils.unet_to_diffusers(unet_config)
new_sd = {}
for k in diffusers_keys:
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device()
model_config = model_detection.model_config_from_unet_config(unet_config)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
if model_config is None:
print("ERROR UNSUPPORTED UNET", unet_path)
return None
diffusers_keys = utils.unet_to_diffusers(model_config.unet_config)
new_sd = {}
for k in diffusers_keys:
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def save_checkpoint(output_path, model, clip, vae, metadata=None):
try:


@@ -126,7 +126,8 @@ class SDXLRefiner(supported_models_base.BASE):
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
replace_prefix["clip_g"] = "conditioner.embedders.0.model" replace_prefix["clip_g"] = "conditioner.embedders.0.model"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix) state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g return state_dict_g
@ -171,7 +172,8 @@ class SDXL(supported_models_base.BASE):
replace_prefix = {}
keys_to_replace = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
for k in state_dict:
if k.startswith("clip_l"):
state_dict_g[k] = state_dict[k]


@@ -120,20 +120,24 @@ UNET_MAP_RESNET = {
}
UNET_MAP_BASIC = {
"label_emb.0.0.weight": "class_embedding.linear_1.weight",
"label_emb.0.0.bias": "class_embedding.linear_1.bias",
"label_emb.0.2.weight": "class_embedding.linear_2.weight",
"label_emb.0.2.bias": "class_embedding.linear_2.bias",
"input_blocks.0.0.weight": "conv_in.weight",
"input_blocks.0.0.bias": "conv_in.bias",
"out.0.weight": "conv_norm_out.weight",
"out.0.bias": "conv_norm_out.bias",
"out.2.weight": "conv_out.weight",
"out.2.bias": "conv_out.bias",
"time_embed.0.weight": "time_embedding.linear_1.weight",
"time_embed.0.bias": "time_embedding.linear_1.bias",
"time_embed.2.weight": "time_embedding.linear_2.weight",
"time_embed.2.bias": "time_embedding.linear_2.bias"
("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias")
}
def unet_to_diffusers(unet_config):
@@ -208,7 +212,7 @@ def unet_to_diffusers(unet_config):
n += 1
for k in UNET_MAP_BASIC:
diffusers_unet_map[UNET_MAP_BASIC[k]] = k
diffusers_unet_map[k[1]] = k[0]
return diffusers_unet_map
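Because `UNET_MAP_BASIC` is now a set of `(internal, diffusers)` pairs rather than a dict, two diffusers names can resolve to the same internal key (class_embedding for SD2.x unCLIP, add_embedding for SDXL); a small sketch of the reversed lookup built above:
```python
# Sketch: duplicate internal keys are allowed once the mapping is a set of pairs.
sample_pairs = {
    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
}
diffusers_unet_map = {}
for k in sample_pairs:
    diffusers_unet_map[k[1]] = k[0]  # diffusers name -> internal name
assert diffusers_unet_map["add_embedding.linear_1.weight"] == "label_emb.0.0.weight"
```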


@@ -1,9 +1,13 @@
import comfy.sd
import comfy.utils
import comfy.model_base
import folder_paths
import json
import os
from comfy.cli_args import args
class ModelMergeSimple:
@classmethod
def INPUT_TYPES(s):
@@ -99,10 +103,36 @@ class CheckpointSave:
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
metadata = {}
enable_modelspec = True
if isinstance(model.model, comfy.model_base.SDXL):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
elif isinstance(model.model, comfy.model_base.SDXLRefiner):
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
else:
enable_modelspec = False
if enable_modelspec:
metadata["modelspec.sai_model_spec"] = "1.0.0"
metadata["modelspec.implementation"] = "sgm"
metadata["modelspec.title"] = "{} {}".format(filename, counter)
#TODO:
# "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
# "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
# "v2-inpainting"
if model.model.model_type == comfy.model_base.ModelType.EPS:
metadata["modelspec.predict_key"] = "epsilon"
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
metadata["modelspec.predict_key"] = "v"
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)


@@ -37,12 +37,23 @@ class ImageUpscaleWithModel:
device = model_management.get_torch_device()
upscale_model.to(device)
in_img = image.movedim(-1,-3).to(device)
free_memory = model_management.get_free_memory(device)
tile = 512
overlap = 32
oom = True
while oom:
try:
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
tile //= 2
if tile < 128:
raise e
tile = 128 + 64
overlap = 8
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
upscale_model.cpu()
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
return (s,)

cuda_malloc.py (new file, 81 lines)

@ -0,0 +1,81 @@
import os
import importlib.util
from comfy.cli_args import args
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
if os.name == 'nt':
import ctypes
# Define necessary C structures and types
class DISPLAY_DEVICEA(ctypes.Structure):
_fields_ = [
('cb', ctypes.c_ulong),
('DeviceName', ctypes.c_char * 32),
('DeviceString', ctypes.c_char * 128),
('StateFlags', ctypes.c_ulong),
('DeviceID', ctypes.c_char * 128),
('DeviceKey', ctypes.c_char * 128)
]
# Load user32.dll
user32 = ctypes.windll.user32
# Call EnumDisplayDevicesA
def enum_display_devices():
device_info = DISPLAY_DEVICEA()
device_info.cb = ctypes.sizeof(device_info)
device_index = 0
gpu_names = set()
while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
device_index += 1
gpu_names.add(device_info.DeviceString.decode('utf-8'))
return gpu_names
return enum_display_devices()
else:
return set()
def cuda_malloc_supported():
blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
"GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
"Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
"Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000"}
try:
names = get_gpu_names()
except:
names = set()
for x in names:
if "NVIDIA" in x:
for b in blacklist:
if b in x:
return False
return True
if not args.cuda_malloc:
try:
version = ""
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
ver_file = os.path.join(folder, "version.py")
if os.path.isfile(ver_file):
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
version = module.__version__
if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
args.cuda_malloc = cuda_malloc_supported()
except:
pass
if args.cuda_malloc and not args.disable_cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None:
env_var = "backend:cudaMallocAsync"
else:
env_var += ",backend:cudaMallocAsync"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
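A hedged way to confirm the allocator torch actually picked up (assumes torch 2.0+; `get_allocator_backend` is a torch API, not part of this commit):
```python
# Sketch: verify the allocator backend after cuda_malloc.py has run.
import os
print(os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))   # e.g. "backend:cudaMallocAsync"
import torch
print(torch.cuda.get_allocator_backend())          # "cudaMallocAsync" or "native"
```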


@@ -1,6 +1,5 @@
import torch
from PIL import Image, ImageOps
from PIL import Image
from io import BytesIO
import struct
import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
@@ -15,26 +14,7 @@ class LatentPreviewer:
def decode_latent_to_preview_image(self, preview_format, x0):
preview_image = self.decode_latent_to_preview(x0)
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), resampling)
preview_type = 1
if preview_format == "JPEG":
preview_type = 1
elif preview_format == "PNG":
preview_type = 2
bytesIO = BytesIO()
header = struct.pack(">I", preview_type)
bytesIO.write(header)
preview_image.save(bytesIO, format=preview_format, quality=95)
preview_bytes = bytesIO.getvalue()
return preview_bytes
class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd):

main.py (31 changed lines)

@@ -61,30 +61,7 @@ if __name__ == "__main__":
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device)
if not args.cuda_malloc:
import cuda_malloc
try: #if there's a better way to check the torch version without importing it let me know
version = ""
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
ver_file = os.path.join(folder, "version.py")
if os.path.isfile(ver_file):
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
version = module.__version__
if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
args.cuda_malloc = True
except:
pass
if args.cuda_malloc and not args.disable_cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None:
env_var = "backend:cudaMallocAsync"
else:
env_var += ",backend:cudaMallocAsync"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
import comfy.utils
import yaml
@@ -115,10 +92,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
def hijack_progress(server):
def hook(value, total, preview_image_bytes):
def hook(value, total, preview_image):
server.send_sync("progress", {"value": value, "max": total}, server.client_id)
if preview_image_bytes is not None:
if preview_image is not None:
server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook)

nodes.py (100 changed lines)

@@ -26,6 +26,8 @@ import comfy.utils
import comfy.clip_vision
import comfy.model_management
from comfy.cli_args import args
import importlib
import folder_paths
@@ -204,6 +206,28 @@ class ConditioningZeroOut:
c.append(n)
return (c, )
class ConditioningSetTimestepRange:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "set_range"
CATEGORY = "advanced/conditioning"
def set_range(self, conditioning, start, end):
c = []
for t in conditioning:
d = t[1].copy()
d['start_percent'] = 1.0 - start
d['end_percent'] = 1.0 - end
n = [t[0], d]
c.append(n)
return (c, )
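A worked example (illustrative values) of how the node's `start`/`end` inputs are inverted into the `start_percent`/`end_percent` keys read by the sampler:
```python
# Example: start=0.0, end=0.5 keeps the conditioning for the first half of sampling.
start, end = 0.0, 0.5
d = {'start_percent': 1.0 - start,  # 1.0 -> active from the first step
     'end_percent': 1.0 - end}      # 0.5 -> dropped once sampling is half done
print(d)  # {'start_percent': 1.0, 'end_percent': 0.5}
```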
class VAEDecode:
@classmethod
def INPUT_TYPES(s):
@@ -330,10 +354,12 @@ class SaveLatent:
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
metadata = None
if not args.disable_metadata:
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
file = f"{filename}_{counter:05}_.latent"
file = os.path.join(full_output_folder, file)
@@ -580,9 +606,58 @@ class ControlNetApply:
if 'control' in t[1]:
c_net.set_previous_controlnet(t[1]['control'])
n[1]['control'] = c_net
n[1]['control_apply_to_uncond'] = True
c.append(n)
return (c, )
class ControlNetApplyAdvanced:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"image": ("IMAGE", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "apply_controlnet"
CATEGORY = "conditioning"
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent):
if strength == 0:
return (positive, negative)
control_hint = image.movedim(-1,1)
cnets = {}
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
prev_cnet = d.get('control', None)
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
d['control'] = c_net
d['control_apply_to_uncond'] = False
n = [t[0], d]
c.append(n)
out.append(c)
return (out[0], out[1])
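A hedged sketch of calling the new node's function directly; the conditioning lists, controlnet and image tensor are placeholders:
```python
# Sketch: apply one controlnet to both positive and negative conditioning
# for the middle 80% of sampling.
node = ControlNetApplyAdvanced()
positive_out, negative_out = node.apply_controlnet(
    positive, negative, control_net, image,
    strength=1.0, start_percent=0.1, end_percent=0.9)
```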
class UNETLoader:
@classmethod
def INPUT_TYPES(s):
@@ -1143,12 +1218,14 @@ class SaveImage:
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
file = f"{filename}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
@@ -1427,6 +1504,7 @@ NODE_CLASS_MAPPINGS = {
"StyleModelApply": StyleModelApply,
"unCLIPConditioning": unCLIPConditioning,
"ControlNetApply": ControlNetApply,
"ControlNetApplyAdvanced": ControlNetApplyAdvanced,
"ControlNetLoader": ControlNetLoader, "ControlNetLoader": ControlNetLoader,
"DiffControlNetLoader": DiffControlNetLoader, "DiffControlNetLoader": DiffControlNetLoader,
"StyleModelLoader": StyleModelLoader, "StyleModelLoader": StyleModelLoader,
@@ -1444,6 +1522,7 @@ NODE_CLASS_MAPPINGS = {
"SaveLatent": SaveLatent,
"ConditioningZeroOut": ConditioningZeroOut,
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@@ -1472,6 +1551,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetMask": "Conditioning (Set Mask)",
"ControlNetApply": "Apply ControlNet",
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
# Latent
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
"SetLatentNoiseMask": "Set Latent Noise Mask",


@@ -69,6 +69,13 @@
"source": [
"# Checkpoints\n",
"\n",
"### SDXL\n",
"### I recommend these workflow examples: https://comfyanonymous.github.io/ComfyUI_examples/sdxl/\n",
"\n",
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n",
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n",
"\n",
"\n",
"# SD1.5\n", "# SD1.5\n",
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n", "!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
"\n", "\n",
@@ -83,7 +90,7 @@
"#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n",
"\n",
"# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n",
"#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta2/resolve/main/checkpoints/wd-1-5-beta2-fp16.safetensors -P ./models/checkpoints/\n",
"#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta3/resolve/main/wd-illusion-fp16.safetensors -P ./models/checkpoints/\n",
"\n",
"\n",
"# unCLIP models\n",
@ -100,6 +107,7 @@
"# Loras\n", "# Loras\n",
"#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n", "#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n",
"#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n", "#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n",
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors -P ./models/loras/ #SDXL offset noise lora\n",
"\n", "\n",
"\n", "\n",
"# T2I-Adapter\n", "# T2I-Adapter\n",


@@ -1,5 +1,4 @@
torch
torchdiffeq
torchsde
einops
transformers>=4.25.1


@@ -2,8 +2,12 @@ import json
from urllib import request, parse
import random
#this is the ComfyUI api prompt format. If you want it for a specific workflow you can copy it from the prompt section
#of the image metadata of images generated with ComfyUI
#This is the ComfyUI api prompt format.
#If you want it for a specific workflow you can "enable dev mode options"
#in the settings of the UI (gear beside the "Queue Size: ") this will enable
#a button on the UI to save workflows in api format.
#keep in mind ComfyUI is pre alpha software so this format will change a bit.
#this is the one for the default workflow


@@ -8,7 +8,7 @@ import uuid
import json
import glob
import struct
from PIL import Image
from PIL import Image, ImageOps
from io import BytesIO
try:
@@ -29,6 +29,7 @@ import comfy.model_management
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
async def send_socket_catch_exception(function, message):
try:
@@ -498,7 +499,9 @@ class PromptServer():
return prompt_info
async def send(self, event, data, sid=None):
if isinstance(data, (bytes, bytearray)):
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
await self.send_image(data, sid=sid)
elif isinstance(data, (bytes, bytearray)):
await self.send_bytes(event, data, sid)
else:
await self.send_json(event, data, sid)
@@ -512,6 +515,30 @@ class PromptServer():
message.extend(data)
return message
async def send_image(self, image_data, sid=None):
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
if max_size is not None:
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
image = ImageOps.contain(image, (max_size, max_size), resampling)
type_num = 1
if image_type == "JPEG":
type_num = 1
elif image_type == "PNG":
type_num = 2
bytesIO = BytesIO()
header = struct.pack(">I", type_num)
bytesIO.write(header)
image.save(bytesIO, format=image_type, quality=95, compress_level=4)
preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
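A hedged client-side sketch for decoding the resulting binary frame, assuming the layout `[4-byte event id][4-byte image format][image bytes]` implied by `encode_bytes` plus the header written in `send_image`:
```python
# Sketch: decode a binary preview frame received over the websocket.
import struct
from io import BytesIO
from PIL import Image

def decode_preview(frame: bytes):
    event, img_format = struct.unpack(">II", frame[:8])
    fmt = "JPEG" if img_format == 1 else "PNG"
    return fmt, Image.open(BytesIO(frame[8:]))
```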
async def send_bytes(self, event, data, sid=None):
message = self.encode_bytes(event, data)


@@ -30,9 +30,7 @@ export interface ComfyExtension {
getCustomWidgets(
app: ComfyApp
): Promise<
Array<
Record<string, (node, inputName, inputData, app) => { widget?: IWidget; minWidth?: number; minHeight?: number }>
Record<string, (node, inputName, inputData, app) => { widget?: IWidget; minWidth?: number; minHeight?: number }>
>
>; >;
/** /**
* Allows the extension to add additional handling to the node before it is registered with LGraph * Allows the extension to add additional handling to the node before it is registered with LGraph