Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-13 23:12:35 +08:00)

Commit 95c8e22fae: Merge remote-tracking branch 'upstream/master' into node_expansion
@@ -93,8 +93,8 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
 
 ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
 
-This is the command to install the nightly with ROCm 5.5 that supports the 7000 series and might have some performance improvements:
+This is the command to install the nightly with ROCm 5.6 that supports the 7000 series and might have some performance improvements:
 
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.5 -r requirements.txt```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6 -r requirements.txt```
 
 ### NVIDIA
@@ -42,7 +42,7 @@ parser.add_argument("--auto-launch", action="store_true", help="Automatically la
 parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
 cm_group = parser.add_mutually_exclusive_group()
 cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
-cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Enable cudaMallocAsync.")
+cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
 
 parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
@@ -84,6 +84,8 @@ parser.add_argument("--dont-print-server", action="store_true", help="Don't prin
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
 parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
 
+parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
+
 args = parser.parse_args()
 
 if args.windows_standalone_build:
@@ -148,6 +148,10 @@ vae_conversion_map_attn = [
     ("q.", "query."),
     ("k.", "key."),
     ("v.", "value."),
+    ("q.", "to_q."),
+    ("k.", "to_k."),
+    ("v.", "to_v."),
+    ("proj_out.", "to_out.0."),
     ("proj_out.", "proj_attn."),
 ]
@@ -91,7 +91,9 @@ class DiscreteSchedule(nn.Module):
         return log_sigma.exp()
 
     def predict_eps_discrete_timestep(self, input, t, **kwargs):
-        sigma = self.t_to_sigma(t.round())
+        if t.dtype != torch.int64 and t.dtype != torch.int32:
+            t = t.round()
+        sigma = self.t_to_sigma(t)
         input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
         return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
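Reviewer note: the guard above only rounds `t` when it arrives as a float tensor, so integer timestep tensors pass through unchanged. For reference, a standalone sketch of the eps recovery this method performs (`append_dims` is re-implemented here just for the demo; this is not the repository's code):

```python
import torch

def append_dims(x, target_dims):
    # Right-pad x's shape with singleton dims so it broadcasts over images.
    return x[(...,) + (None,) * (target_dims - x.ndim)]

def eps_from_denoiser(denoiser, x, sigma):
    # The discrete schedule scales the latent by sqrt(sigma^2 + 1) before
    # calling the model, then recovers eps = (input - denoised) / sigma.
    x_in = x * ((append_dims(sigma, x.ndim) ** 2 + 1.0) ** 0.5)
    denoised = denoiser(x_in, sigma)
    return (x_in - denoised) / append_dims(sigma, x.ndim)

# Demo with a dummy denoiser that predicts all zeros:
x = torch.randn(2, 4, 8, 8)
sigma = torch.tensor([1.0, 2.0])
print(eps_from_denoiser(lambda a, s: torch.zeros_like(a), x, sigma).shape)
```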
@@ -3,7 +3,6 @@ import math
 
 from scipy import integrate
 import torch
 from torch import nn
-from torchdiffeq import odeint
 import torchsde
 from tqdm.auto import trange, tqdm
@@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
     return x
 
 
-@torch.no_grad()
-def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    v = torch.randint_like(x, 2) * 2 - 1
-    fevals = 0
-    def ode_fn(sigma, x):
-        nonlocal fevals
-        with torch.enable_grad():
-            x = x[0].detach().requires_grad_()
-            denoised = model(x, sigma * s_in, **extra_args)
-            d = to_d(x, sigma, denoised)
-            fevals += 1
-            grad = torch.autograd.grad((d * v).sum(), x)[0]
-            d_ll = (v * grad).flatten(1).sum(1)
-        return d.detach(), d_ll
-    x_min = x, x.new_zeros([x.shape[0]])
-    t = x.new_tensor([sigma_min, sigma_max])
-    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
-    latent, delta_ll = sol[0][-1], sol[1][-1]
-    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
-    return ll_prior + delta_ll, {'fevals': fevals}
-
-
 class PIDStepSizeController:
     """A PID controller for ODE adaptive step size control."""
     def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
@@ -164,7 +164,6 @@ class SDXLRefiner(BaseModel):
         else:
             aesthetic_score = kwargs.get("aesthetic_score", 6)
 
-        print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
         out = []
         out.append(self.embedder(torch.Tensor([height])))
         out.append(self.embedder(torch.Tensor([width])))
@@ -188,7 +187,6 @@ class SDXL(BaseModel):
         target_width = kwargs.get("target_width", width)
         target_height = kwargs.get("target_height", height)
 
-        print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
         out = []
         out.append(self.embedder(torch.Tensor([height])))
         out.append(self.embedder(torch.Tensor([width])))
@@ -118,3 +118,57 @@ def model_config_from_unet_config(unet_config):
 def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
     unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
     return model_config_from_unet_config(unet_config)
+
+def model_config_from_diffusers_unet(state_dict, use_fp16):
+    match = {}
+    match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
+    match["model_channels"] = state_dict["conv_in.weight"].shape[0]
+    match["in_channels"] = state_dict["conv_in.weight"].shape[1]
+    match["adm_in_channels"] = None
+    if "class_embedding.linear_1.weight" in state_dict:
+        match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
+    elif "add_embedding.linear_1.weight" in state_dict:
+        match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
+
+    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
+            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+
+    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
+                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
+
+    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
+
+    for unet_config in supported_models:
+        matches = True
+        for k in match:
+            if match[k] != unet_config[k]:
+                matches = False
+                break
+        if matches:
+            return model_config_from_unet_config(unet_config)
+    return None
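The new `model_config_from_diffusers_unet` identifies an architecture by probing a few tensor shapes in the state dict and comparing them against hard-coded per-model values. A minimal sketch of that fingerprinting idea (toy candidate table, not the real config dicts):

```python
import torch

def detect_architecture(state_dict, candidates):
    # Probe shapes that differ between architectures.
    fingerprint = {
        "model_channels": state_dict["conv_in.weight"].shape[0],
        "in_channels": state_dict["conv_in.weight"].shape[1],
    }
    for name, config in candidates.items():
        if all(config[k] == v for k, v in fingerprint.items()):
            return name
    return None  # mirrors the function's None on an unknown UNet

candidates = {
    "sdxl-like": {"model_channels": 320, "in_channels": 4},
    "refiner-like": {"model_channels": 384, "in_channels": 4},
}
sd = {"conv_in.weight": torch.zeros(384, 4, 3, 3)}
print(detect_architecture(sd, candidates))  # refiner-like
```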
@@ -49,6 +49,7 @@ except:
 try:
     if torch.backends.mps.is_available():
         cpu_state = CPUState.MPS
+        import torch.mps
 except:
     pass
@@ -280,19 +281,23 @@ def load_model_gpu(model):
         vram_set_state = VRAMState.LOW_VRAM
 
     real_model = model.model
+    patch_model_to = None
     if vram_set_state == VRAMState.DISABLED:
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
-        real_model.to(torch_dev)
+        patch_model_to = torch_dev
 
     try:
-        real_model = model.patch_model()
+        real_model = model.patch_model(device_to=patch_model_to)
     except Exception as e:
         model.unpatch_model()
         unload_model()
         raise e
 
+    if patch_model_to is not None:
+        real_model.to(torch_dev)
+
     if vram_set_state == VRAMState.NO_VRAM:
         device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
@@ -529,7 +534,7 @@ def should_use_fp16(device=None, model_params=0):
             return False
 
         #FP16 is just broken on these cards
-        nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600"]
+        nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450"]
         for x in nvidia_16_series:
             if x in props.name:
                 return False
@@ -17,6 +17,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
     def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
+        if 'timestep_start' in cond[1]:
+            timestep_start = cond[1]['timestep_start']
+            if timestep_in[0] > timestep_start:
+                return None
+        if 'timestep_end' in cond[1]:
+            timestep_end = cond[1]['timestep_end']
+            if timestep_in[0] < timestep_end:
+                return None
         if 'area' in cond[1]:
             area = cond[1]['area']
         if 'strength' in cond[1]:
@@ -248,7 +256,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
 
             c['transformer_options'] = transformer_options
 
-            output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
+            if 'model_function_wrapper' in model_options:
+                output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
+            else:
+                output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
             del input_x
 
             model_management.throw_exception_if_processing_interrupted()
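The new `model_function_wrapper` hook lets custom code intercept every UNet call. Per the diff, the wrapper receives the raw apply function plus a dict with `input`, `timestep`, `c`, and `cond_or_uncond`. A hypothetical pass-through wrapper:

```python
def my_wrapper(model_function, params):
    # params carries exactly what sampling_function would otherwise pass.
    x = params["input"]
    t = params["timestep"]
    c = params["c"]
    # params["cond_or_uncond"] identifies cond vs uncond batch entries.
    return model_function(x, t, **c)

# Toy check with a stand-in model function:
out = my_wrapper(lambda x, t, **c: x * t,
                 {"input": 2.0, "timestep": 3.0, "c": {}, "cond_or_uncond": [0]})
print(out)  # 6.0
# Registered via: model_options = {"model_function_wrapper": my_wrapper}
```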
@@ -425,6 +436,35 @@ def create_cond_with_same_area_if_none(conds, c):
     n = c[1].copy()
     conds += [[smallest[0], n]]
 
+def calculate_start_end_timesteps(model, conds):
+    for t in range(len(conds)):
+        x = conds[t]
+
+        timestep_start = None
+        timestep_end = None
+        if 'start_percent' in x[1]:
+            timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
+        if 'end_percent' in x[1]:
+            timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
+
+        if (timestep_start is not None) or (timestep_end is not None):
+            n = x[1].copy()
+            if (timestep_start is not None):
+                n['timestep_start'] = timestep_start
+            if (timestep_end is not None):
+                n['timestep_end'] = timestep_end
+            conds[t] = [x[0], n]
+
+def pre_run_control(model, conds):
+    for t in range(len(conds)):
+        x = conds[t]
+
+        timestep_start = None
+        timestep_end = None
+        percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
+        if 'control' in x[1]:
+            x[1]['control'].pre_run(model.inner_model, percent_to_timestep_function)
+
 def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
     cond_cnets = []
     cond_other = []
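Both helpers above convert a user-facing 0..1 percent into a model timestep through a sigma round trip. A toy reconstruction of that mapping with a fake monotone schedule (illustrative only, not the model's real `t_to_sigma`/`sigma_to_t`):

```python
import torch

sigmas = torch.linspace(0.03, 15.0, 1000)  # stand-in discrete schedule

def t_to_sigma(t):
    return sigmas[t.long()]

def sigma_to_t(sigma):
    return torch.argmin((sigmas - sigma).abs())

def percent_to_timestep(percent):
    # Same shape as the lambda in pre_run_control.
    return sigma_to_t(t_to_sigma(torch.tensor(percent * 999.0)))

print(percent_to_timestep(0.0), percent_to_timestep(1.0))  # tensor(0) tensor(999)
```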
@@ -568,13 +608,18 @@ class KSampler:
         resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
         resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
 
+        calculate_start_end_timesteps(self.model_wrap, negative)
+        calculate_start_end_timesteps(self.model_wrap, positive)
+
         #make sure each cond area has an opposite one with the same area
         for c in positive:
             create_cond_with_same_area_if_none(negative, c)
         for c in negative:
             create_cond_with_same_area_if_none(positive, c)
 
-        apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
+        pre_run_control(self.model_wrap, negative + positive)
+
+        apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
         apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
 
         if self.model.is_adm():
289 comfy/sd.py
@@ -170,6 +170,8 @@ def model_lora_keys_clip(model, key_map={}):
         if k in sdk:
             lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
             key_map[lora_key] = k
+            lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
+            key_map[lora_key] = k
 
         k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
         if k in sdk:
@@ -202,6 +204,14 @@ def model_lora_keys_unet(model, key_map={}):
         key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
     return key_map
 
+def set_attr(obj, attr, value):
+    attrs = attr.split(".")
+    for name in attrs[:-1]:
+        obj = getattr(obj, name)
+    prev = getattr(obj, attrs[-1])
+    setattr(obj, attrs[-1], torch.nn.Parameter(value))
+    del prev
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0):
         self.size = size
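`set_attr` is promoted to module level here (it previously lived inside `unpatch_model`, removed further down) so `patch_model` can swap patched weights in as fresh Parameters. Standalone demo of the mechanics (copy of the helper plus a hypothetical model):

```python
import torch

def set_attr(obj, attr, value):
    # Walk the dotted path, then replace the leaf with a new Parameter.
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    setattr(obj, attrs[-1], torch.nn.Parameter(value))
    del prev

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
set_attr(model, "0.weight", torch.zeros(4, 4))
print(model[0].weight.abs().sum().item())  # 0.0
```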
@@ -330,7 +340,7 @@ class ModelPatcher:
             sd.pop(k)
         return sd
 
-    def patch_model(self):
+    def patch_model(self, device_to=None):
         model_sd = self.model_state_dict()
         for key in self.patches:
             if key not in model_sd:
@@ -340,10 +350,14 @@ class ModelPatcher:
             weight = model_sd[key]
 
             if key not in self.backup:
-                self.backup[key] = weight.to(self.offload_device, copy=True)
+                self.backup[key] = weight.to(self.offload_device)
 
-            temp_weight = weight.to(torch.float32, copy=True)
-            weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            if device_to is not None:
+                temp_weight = weight.float().to(device_to, copy=True)
+            else:
+                temp_weight = weight.to(torch.float32, copy=True)
+            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            set_attr(self.model, key, out_weight)
             del temp_weight
         return self.model
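The patching is now out of place: instead of writing through the original tensor with `weight[:] = ...`, the patched result is computed in fp32 (optionally on `device_to`) and swapped in via `set_attr`, with the untouched original kept in `self.backup`. A minimal sketch of the pattern under those assumptions (hypothetical helper, not the class itself):

```python
import torch

def patch_weight(module, name, delta, backup, device_to=None):
    weight = getattr(module, name)
    if name not in backup:
        backup[name] = weight.detach().cpu()  # offload target for the demo
    temp = weight.detach().float()
    if device_to is not None:
        temp = temp.to(device_to)
    out = (temp + delta.to(temp.device)).to(weight.dtype)
    setattr(module, name, torch.nn.Parameter(out))  # out-of-place swap

backup = {}
lin = torch.nn.Linear(2, 2)
patch_weight(lin, "weight", torch.ones(2, 2), backup)
print(sorted(backup))  # ['weight']
```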
@@ -376,7 +390,10 @@ class ModelPatcher:
                     mat3 = v[3].float().to(weight.device)
                     final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
-                weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                try:
+                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
             elif len(v) == 8: #lokr
                 w1 = v[0]
                 w2 = v[1]
@@ -407,7 +424,10 @@ class ModelPatcher:
                 if v[2] is not None and dim is not None:
                     alpha *= v[2] / dim
 
-                weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                try:
+                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
             else: #loha
                 w1a = v[0]
                 w1b = v[1]
@@ -424,18 +444,15 @@ class ModelPatcher:
                 m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
                 m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
 
-                weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                try:
+                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
 
         return weight
 
     def unpatch_model(self):
         keys = list(self.backup.keys())
-        def set_attr(obj, attr, value):
-            attrs = attr.split(".")
-            for name in attrs[:-1]:
-                obj = getattr(obj, name)
-            prev = getattr(obj, attrs[-1])
-            setattr(obj, attrs[-1], torch.nn.Parameter(value))
-            del prev
-
         for k in keys:
             set_attr(self.model, k, self.backup[k])
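All three LoRA-style branches (lora, lokr, loha) now wrap the weight update in try/except, so a patch whose shapes do not match the target layer prints a per-key error instead of aborting the whole patch pass. The failure mode being guarded, in miniature:

```python
import torch

weight = torch.zeros(4, 4)
delta = torch.ones(2, 2)
try:
    weight += delta.reshape(weight.shape)  # shape mismatch raises
except Exception as e:
    print("ERROR", "example.key", e)  # patching continues for other keys
```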
@@ -658,16 +675,57 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
     else:
         return torch.cat([tensor] * batched_number, dim=0)
 
-class ControlNet:
-    def __init__(self, control_model, global_average_pooling=False, device=None):
-        self.control_model = control_model
+class ControlBase:
+    def __init__(self, device=None):
         self.cond_hint_original = None
         self.cond_hint = None
         self.strength = 1.0
+        self.timestep_percent_range = (1.0, 0.0)
+        self.timestep_range = None
+
         if device is None:
             device = model_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
+
+    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
+        self.cond_hint_original = cond_hint
+        self.strength = strength
+        self.timestep_percent_range = timestep_percent_range
+        return self
+
+    def pre_run(self, model, percent_to_timestep_function):
+        self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
+        if self.previous_controlnet is not None:
+            self.previous_controlnet.pre_run(model, percent_to_timestep_function)
+
+    def set_previous_controlnet(self, controlnet):
+        self.previous_controlnet = controlnet
+        return self
+
+    def cleanup(self):
+        if self.previous_controlnet is not None:
+            self.previous_controlnet.cleanup()
+        if self.cond_hint is not None:
+            del self.cond_hint
+            self.cond_hint = None
+        self.timestep_range = None
+
+    def get_models(self):
+        out = []
+        if self.previous_controlnet is not None:
+            out += self.previous_controlnet.get_models()
+        return out
+
+    def copy_to(self, c):
+        c.cond_hint_original = self.cond_hint_original
+        c.strength = self.strength
+        c.timestep_percent_range = self.timestep_percent_range
+
+class ControlNet(ControlBase):
+    def __init__(self, control_model, global_average_pooling=False, device=None):
+        super().__init__(device)
+        self.control_model = control_model
         self.global_average_pooling = global_average_pooling
 
     def get_control(self, x_noisy, t, cond, batched_number):
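With the shared `ControlBase`, both ControlNet and T2I adapters gain a timestep window: `set_cond_hint` records a percent range and `pre_run` converts it to concrete timesteps once the sampler's schedule is known. A stripped-down copy of that flow (demo mapping only, not the real class):

```python
class ControlBaseSketch:
    def __init__(self):
        self.timestep_percent_range = (1.0, 0.0)
        self.timestep_range = None

    def set_cond_hint_range(self, timestep_percent_range):
        self.timestep_percent_range = timestep_percent_range
        return self

    def pre_run(self, percent_to_timestep):
        self.timestep_range = (percent_to_timestep(self.timestep_percent_range[0]),
                               percent_to_timestep(self.timestep_percent_range[1]))

    def active_at(self, t):
        start, end = self.timestep_range
        return not (t > start or t < end)  # same test as get_control

c = ControlBaseSketch().set_cond_hint_range((1.0, 0.4))
c.pre_run(lambda p: p * 999.0)  # identity-style mapping for the demo
print(c.active_at(800.0), c.active_at(100.0))  # True False
```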
@@ -675,6 +733,13 @@ class ControlNet:
         if self.previous_controlnet is not None:
             control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
 
+        if self.timestep_range is not None:
+            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+                if control_prev is not None:
+                    return control_prev
+                else:
+                    return {}
+
         output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
@@ -722,37 +787,64 @@ class ControlNet:
             out['input'] = control_prev['input']
         return out
 
-    def set_cond_hint(self, cond_hint, strength=1.0):
-        self.cond_hint_original = cond_hint
-        self.strength = strength
-        return self
-
-    def set_previous_controlnet(self, controlnet):
-        self.previous_controlnet = controlnet
-        return self
-
-    def cleanup(self):
-        if self.previous_controlnet is not None:
-            self.previous_controlnet.cleanup()
-        if self.cond_hint is not None:
-            del self.cond_hint
-            self.cond_hint = None
-
     def copy(self):
         c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
-        c.cond_hint_original = self.cond_hint_original
-        c.strength = self.strength
+        self.copy_to(c)
         return c
 
     def get_models(self):
-        out = []
-        if self.previous_controlnet is not None:
-            out += self.previous_controlnet.get_models()
+        out = super().get_models()
         out.append(self.control_model)
         return out
 
 
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
+
+    controlnet_config = None
+    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
+        use_fp16 = model_management.should_use_fp16()
+        controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config
+        diffusers_keys = utils.unet_to_diffusers(controlnet_config)
+        diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
+        diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
+
+        count = 0
+        loop = True
+        while loop:
+            suffix = [".weight", ".bias"]
+            for s in suffix:
+                k_in = "controlnet_down_blocks.{}{}".format(count, s)
+                k_out = "zero_convs.{}.0{}".format(count, s)
+                if k_in not in controlnet_data:
+                    loop = False
+                    break
+                diffusers_keys[k_in] = k_out
+            count += 1
+
+        count = 0
+        loop = True
+        while loop:
+            suffix = [".weight", ".bias"]
+            for s in suffix:
+                if count == 0:
+                    k_in = "controlnet_cond_embedding.conv_in{}".format(s)
+                else:
+                    k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
+                k_out = "input_hint_block.{}{}".format(count * 2, s)
+                if k_in not in controlnet_data:
+                    k_in = "controlnet_cond_embedding.conv_out{}".format(s)
+                    loop = False
+                diffusers_keys[k_in] = k_out
+            count += 1
+
+        new_sd = {}
+        for k in diffusers_keys:
+            if k in controlnet_data:
+                new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
+
+        controlnet_data = new_sd
+
     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
     key = 'zero_convs.0.0.weight'
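The diffusers branch above builds one big `diffusers_keys` map and then does a single rename pass over the state dict. The pass itself reduces to this pattern (toy data):

```python
import torch

def rename_keys(state_dict, key_map):
    out = {}
    for src, dst in key_map.items():
        if src in state_dict:
            out[dst] = state_dict.pop(src)
    return out

sd = {"controlnet_mid_block.weight": torch.zeros(3)}
key_map = {"controlnet_mid_block.weight": "middle_block_out.0.weight"}
print(list(rename_keys(sd, key_map)))  # ['middle_block_out.0.weight']
```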
@@ -768,9 +860,9 @@ def load_controlnet(ckpt_path, model=None):
             print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
             return net
 
-    use_fp16 = model_management.should_use_fp16()
-
-    controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
+    if controlnet_config is None:
+        use_fp16 = model_management.should_use_fp16()
+        controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = 3
     control_model = cldm.ControlNet(**controlnet_config)
@@ -810,24 +902,25 @@ def load_controlnet(ckpt_path, model=None):
     control = ControlNet(control_model, global_average_pooling=global_average_pooling)
     return control
 
-class T2IAdapter:
+class T2IAdapter(ControlBase):
     def __init__(self, t2i_model, channels_in, device=None):
+        super().__init__(device)
         self.t2i_model = t2i_model
         self.channels_in = channels_in
-        self.strength = 1.0
-        if device is None:
-            device = model_management.get_torch_device()
-        self.device = device
-        self.previous_controlnet = None
         self.control_input = None
-        self.cond_hint_original = None
-        self.cond_hint = None
 
     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
         if self.previous_controlnet is not None:
             control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
 
+        if self.timestep_range is not None:
+            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+                if control_prev is not None:
+                    return control_prev
+                else:
+                    return {}
+
         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
                 del self.cond_hint
@@ -872,33 +965,11 @@ class T2IAdapter:
             out['output'] = control_prev['output']
         return out
 
-    def set_cond_hint(self, cond_hint, strength=1.0):
-        self.cond_hint_original = cond_hint
-        self.strength = strength
-        return self
-
-    def set_previous_controlnet(self, controlnet):
-        self.previous_controlnet = controlnet
-        return self
-
     def copy(self):
         c = T2IAdapter(self.t2i_model, self.channels_in)
-        c.cond_hint_original = self.cond_hint_original
-        c.strength = self.strength
+        self.copy_to(c)
         return c
 
-    def cleanup(self):
-        if self.previous_controlnet is not None:
-            self.previous_controlnet.cleanup()
-        if self.cond_hint is not None:
-            del self.cond_hint
-            self.cond_hint = None
-
-    def get_models(self):
-        out = []
-        if self.previous_controlnet is not None:
-            out += self.previous_controlnet.get_models()
-        return out
-
 def load_t2i_adapter(t2i_data):
     keys = t2i_data.keys()
@@ -1128,66 +1199,24 @@ def load_unet(unet_path): #load unet in diffusers format
     parameters = calculate_parameters(sd, "")
     fp16 = model_management.should_use_fp16(model_params=parameters)
 
-    match = {}
-    match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
-    match["model_channels"] = sd["conv_in.weight"].shape[0]
-    match["in_channels"] = sd["conv_in.weight"].shape[1]
-    match["adm_in_channels"] = None
-    if "class_embedding.linear_1.weight" in sd:
-        match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
-
-    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
-            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
-            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
-
-    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
-                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
-
-    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
-            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
-
-    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
-                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
-
-    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
-                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
-
-    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
-            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
-
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
-    print("match", match)
-    for unet_config in supported_models:
-        matches = True
-        for k in match:
-            if match[k] != unet_config[k]:
-                matches = False
-                break
-        if matches:
-            diffusers_keys = utils.unet_to_diffusers(unet_config)
-            new_sd = {}
-            for k in diffusers_keys:
-                if k in sd:
-                    new_sd[diffusers_keys[k]] = sd.pop(k)
-                else:
-                    print(diffusers_keys[k], k)
-            offload_device = model_management.unet_offload_device()
-            model_config = model_detection.model_config_from_unet_config(unet_config)
-            model = model_config.get_model(new_sd, "")
-            model = model.to(offload_device)
-            model.load_model_weights(new_sd, "")
-            return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+    model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
+    if model_config is None:
+        print("ERROR UNSUPPORTED UNET", unet_path)
+        return None
+
+    diffusers_keys = utils.unet_to_diffusers(model_config.unet_config)
+
+    new_sd = {}
+    for k in diffusers_keys:
+        if k in sd:
+            new_sd[diffusers_keys[k]] = sd.pop(k)
+        else:
+            print(diffusers_keys[k], k)
+    offload_device = model_management.unet_offload_device()
+    model = model_config.get_model(new_sd, "")
+    model = model.to(offload_device)
+    model.load_model_weights(new_sd, "")
+    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
 
 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     try:
@@ -126,7 +126,8 @@ class SDXLRefiner(supported_models_base.BASE):
     def process_clip_state_dict_for_saving(self, state_dict):
         replace_prefix = {}
         state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
-        state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
+        if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
+            state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
         replace_prefix["clip_g"] = "conditioner.embedders.0.model"
         state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
         return state_dict_g
@@ -171,7 +172,8 @@ class SDXL(supported_models_base.BASE):
         replace_prefix = {}
         keys_to_replace = {}
         state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
-        state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
+        if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
+            state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
         for k in state_dict:
             if k.startswith("clip_l"):
                 state_dict_g[k] = state_dict[k]
@@ -120,20 +120,24 @@ UNET_MAP_RESNET = {
 }
 
 UNET_MAP_BASIC = {
-    "label_emb.0.0.weight": "class_embedding.linear_1.weight",
-    "label_emb.0.0.bias": "class_embedding.linear_1.bias",
-    "label_emb.0.2.weight": "class_embedding.linear_2.weight",
-    "label_emb.0.2.bias": "class_embedding.linear_2.bias",
-    "input_blocks.0.0.weight": "conv_in.weight",
-    "input_blocks.0.0.bias": "conv_in.bias",
-    "out.0.weight": "conv_norm_out.weight",
-    "out.0.bias": "conv_norm_out.bias",
-    "out.2.weight": "conv_out.weight",
-    "out.2.bias": "conv_out.bias",
-    "time_embed.0.weight": "time_embedding.linear_1.weight",
-    "time_embed.0.bias": "time_embedding.linear_1.bias",
-    "time_embed.2.weight": "time_embedding.linear_2.weight",
-    "time_embed.2.bias": "time_embedding.linear_2.bias"
+    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
+    ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
+    ("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
+    ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
+    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
+    ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
+    ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
+    ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
+    ("input_blocks.0.0.weight", "conv_in.weight"),
+    ("input_blocks.0.0.bias", "conv_in.bias"),
+    ("out.0.weight", "conv_norm_out.weight"),
+    ("out.0.bias", "conv_norm_out.bias"),
+    ("out.2.weight", "conv_out.weight"),
+    ("out.2.bias", "conv_out.bias"),
+    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+    ("time_embed.2.bias", "time_embedding.linear_2.bias")
 }
 
 def unet_to_diffusers(unet_config):
@@ -208,7 +212,7 @@ def unet_to_diffusers(unet_config):
             n += 1
 
     for k in UNET_MAP_BASIC:
-        diffusers_unet_map[UNET_MAP_BASIC[k]] = k
+        diffusers_unet_map[k[1]] = k[0]
 
     return diffusers_unet_map
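The switch from a dict to a set of tuples matters because one UNet key must now map to two diffusers names: `label_emb.*` feeds `class_embedding.*` in SD2.x-style models but `add_embedding.*` in SDXL, and dict keys cannot repeat. The inversion loop then reads `k[0]`/`k[1]` accordingly:

```python
pairs = {
    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
}
inverted = {}
for unet_key, diffusers_key in pairs:
    inverted[diffusers_key] = unet_key  # mirrors diffusers_unet_map[k[1]] = k[0]
print(len(pairs), len(inverted))  # 2 2: both targets survive the inversion
```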
@@ -1,9 +1,13 @@
 import comfy.sd
 import comfy.utils
+import comfy.model_base
 
 import folder_paths
 import json
 import os
 
+from comfy.cli_args import args
+
 class ModelMergeSimple:
     @classmethod
     def INPUT_TYPES(s):
@@ -99,10 +103,36 @@ class CheckpointSave:
         if prompt is not None:
             prompt_info = json.dumps(prompt)
 
-        metadata = {"prompt": prompt_info}
-        if extra_pnginfo is not None:
-            for x in extra_pnginfo:
-                metadata[x] = json.dumps(extra_pnginfo[x])
+        metadata = {}
+
+        enable_modelspec = True
+        if isinstance(model.model, comfy.model_base.SDXL):
+            metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
+        elif isinstance(model.model, comfy.model_base.SDXLRefiner):
+            metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
+        else:
+            enable_modelspec = False
+
+        if enable_modelspec:
+            metadata["modelspec.sai_model_spec"] = "1.0.0"
+            metadata["modelspec.implementation"] = "sgm"
+            metadata["modelspec.title"] = "{} {}".format(filename, counter)
+
+        #TODO:
+        # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
+        # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
+        # "v2-inpainting"
+
+        if model.model.model_type == comfy.model_base.ModelType.EPS:
+            metadata["modelspec.predict_key"] = "epsilon"
+        elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
+            metadata["modelspec.predict_key"] = "v"
+
+        if not args.disable_metadata:
+            metadata["prompt"] = prompt_info
+            if extra_pnginfo is not None:
+                for x in extra_pnginfo:
+                    metadata[x] = json.dumps(extra_pnginfo[x])
 
         output_checkpoint = f"{filename}_{counter:05}_.safetensors"
         output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
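For reference, the metadata dict this node would assemble for an SDXL base checkpoint looks roughly like the following (keys and values taken from the branch above; the `prompt` entry only appears when `--disable-metadata` is not set, and the payload here is hypothetical):

```python
import json

metadata = {
    "modelspec.sai_model_spec": "1.0.0",
    "modelspec.architecture": "stable-diffusion-xl-v1-base",
    "modelspec.implementation": "sgm",
    "modelspec.title": "checkpoint 1",
    "modelspec.predict_key": "epsilon",
    "prompt": json.dumps({"example": "workflow"}),  # hypothetical payload
}
print(sorted(metadata))
```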
@@ -37,12 +37,23 @@ class ImageUpscaleWithModel:
         device = model_management.get_torch_device()
         upscale_model.to(device)
         in_img = image.movedim(-1,-3).to(device)
+        free_memory = model_management.get_free_memory(device)
 
-        tile = 128 + 64
-        overlap = 8
-        steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
-        pbar = comfy.utils.ProgressBar(steps)
-        s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
+        tile = 512
+        overlap = 32
+
+        oom = True
+        while oom:
+            try:
+                steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
+                pbar = comfy.utils.ProgressBar(steps)
+                s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
+                oom = False
+            except model_management.OOM_EXCEPTION as e:
+                tile //= 2
+                if tile < 128:
+                    raise e
+
         upscale_model.cpu()
         s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
         return (s,)
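The upscale node now starts at a 512px tile and halves it on CUDA OOM down to a 128px floor, instead of always using the old fixed 192px (128 + 64) tile. The retry policy in isolation (stand-in exception and function for the demo):

```python
class OOM(Exception):
    pass

def run_tiled(tile):
    if tile > 256:  # pretend only smaller tiles fit in memory
        raise OOM("out of memory")
    return "ok with tile %d" % tile

tile = 512
while True:
    try:
        result = run_tiled(tile)
        break
    except OOM:
        tile //= 2
        if tile < 128:
            raise
print(result)  # ok with tile 256
```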
81 cuda_malloc.py (new file)
@@ -0,0 +1,81 @@
+import os
+import importlib.util
+from comfy.cli_args import args
+
+#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
+def get_gpu_names():
+    if os.name == 'nt':
+        import ctypes
+
+        # Define necessary C structures and types
+        class DISPLAY_DEVICEA(ctypes.Structure):
+            _fields_ = [
+                ('cb', ctypes.c_ulong),
+                ('DeviceName', ctypes.c_char * 32),
+                ('DeviceString', ctypes.c_char * 128),
+                ('StateFlags', ctypes.c_ulong),
+                ('DeviceID', ctypes.c_char * 128),
+                ('DeviceKey', ctypes.c_char * 128)
+            ]
+
+        # Load user32.dll
+        user32 = ctypes.windll.user32
+
+        # Call EnumDisplayDevicesA
+        def enum_display_devices():
+            device_info = DISPLAY_DEVICEA()
+            device_info.cb = ctypes.sizeof(device_info)
+            device_index = 0
+            gpu_names = set()
+
+            while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
+                device_index += 1
+                gpu_names.add(device_info.DeviceString.decode('utf-8'))
+            return gpu_names
+        return enum_display_devices()
+    else:
+        return set()
+
+def cuda_malloc_supported():
+    blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
+                 "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
+                 "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
+                 "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000"}
+
+    try:
+        names = get_gpu_names()
+    except:
+        names = set()
+    for x in names:
+        if "NVIDIA" in x:
+            for b in blacklist:
+                if b in x:
+                    return False
+    return True
+
+
+if not args.cuda_malloc:
+    try:
+        version = ""
+        torch_spec = importlib.util.find_spec("torch")
+        for folder in torch_spec.submodule_search_locations:
+            ver_file = os.path.join(folder, "version.py")
+            if os.path.isfile(ver_file):
+                spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
+                module = importlib.util.module_from_spec(spec)
+                spec.loader.exec_module(module)
+                version = module.__version__
+        if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
+            args.cuda_malloc = cuda_malloc_supported()
+    except:
+        pass
+
+
+if args.cuda_malloc and not args.disable_cuda_malloc:
+    env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
+    if env_var is None:
+        env_var = "backend:cudaMallocAsync"
+    else:
+        env_var += ",backend:cudaMallocAsync"
+
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
@@ -1,6 +1,5 @@
 import torch
-from PIL import Image, ImageOps
-from io import BytesIO
+from PIL import Image
 import struct
 import numpy as np
 from comfy.cli_args import args, LatentPreviewMethod
@ -15,26 +14,7 @@ class LatentPreviewer:

     def decode_latent_to_preview_image(self, preview_format, x0):
         preview_image = self.decode_latent_to_preview(x0)
-
-        if hasattr(Image, 'Resampling'):
-            resampling = Image.Resampling.BILINEAR
-        else:
-            resampling = Image.ANTIALIAS
-
-        preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), resampling)
-
-        preview_type = 1
-        if preview_format == "JPEG":
-            preview_type = 1
-        elif preview_format == "PNG":
-            preview_type = 2
-
-        bytesIO = BytesIO()
-        header = struct.pack(">I", preview_type)
-        bytesIO.write(header)
-        preview_image.save(bytesIO, format=preview_format, quality=95)
-        preview_bytes = bytesIO.getvalue()
-        return preview_bytes
+        return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)

 class TAESDPreviewerImpl(LatentPreviewer):
     def __init__(self, taesd):
main.py
@ -61,30 +61,7 @@ if __name__ == "__main__":
         os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
         print("Set cuda device to:", args.cuda_device)

-    if not args.cuda_malloc:
-        try: #if there's a better way to check the torch version without importing it let me know
-            version = ""
-            torch_spec = importlib.util.find_spec("torch")
-            for folder in torch_spec.submodule_search_locations:
-                ver_file = os.path.join(folder, "version.py")
-                if os.path.isfile(ver_file):
-                    spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
-                    module = importlib.util.module_from_spec(spec)
-                    spec.loader.exec_module(module)
-                    version = module.__version__
-            if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
-                args.cuda_malloc = True
-        except:
-            pass
-
-    if args.cuda_malloc and not args.disable_cuda_malloc:
-        env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
-        if env_var is None:
-            env_var = "backend:cudaMallocAsync"
-        else:
-            env_var += ",backend:cudaMallocAsync"
-
-        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
+    import cuda_malloc

 import comfy.utils
 import yaml

@ -115,10 +92,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):


 def hijack_progress(server):
-    def hook(value, total, preview_image_bytes):
+    def hook(value, total, preview_image):
         server.send_sync("progress", {"value": value, "max": total}, server.client_id)
-        if preview_image_bytes is not None:
-            server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
+        if preview_image is not None:
+            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
     comfy.utils.set_progress_bar_global_hook(hook)
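With this change the global progress hook no longer receives pre-encoded JPEG bytes; it gets the unencoded `("JPEG", <PIL.Image>, MAX_PREVIEW_RESOLUTION)` tuple that `decode_latent_to_preview_image` now returns, and the server encodes it later in `send_image`. A hedged sketch of a custom hook under the new convention (the hook name and output path are illustrative):

```python
# Sketch, assuming the new hook signature shown above: preview_image is
# an unencoded ("JPEG", PIL.Image, max_size) tuple, or None.
import comfy.utils

def my_hook(value, total, preview_image):  # hypothetical custom hook
    print(f"step {value}/{total}")
    if preview_image is not None:
        image_format, pil_image, _max_size = preview_image
        pil_image.save(f"preview_{value:04}.jpeg")  # illustrative output path

comfy.utils.set_progress_bar_global_hook(my_hook)
```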
nodes.py
@ -26,6 +26,8 @@ import comfy.utils
 import comfy.clip_vision

 import comfy.model_management
+from comfy.cli_args import args
+
 import importlib

 import folder_paths
@ -204,6 +206,28 @@ class ConditioningZeroOut:
             c.append(n)
         return (c, )

+class ConditioningSetTimestepRange:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", ),
+                             "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                             "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
+                             }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "set_range"
+
+    CATEGORY = "advanced/conditioning"
+
+    def set_range(self, conditioning, start, end):
+        c = []
+        for t in conditioning:
+            d = t[1].copy()
+            d['start_percent'] = 1.0 - start
+            d['end_percent'] = 1.0 - end
+            n = [t[0], d]
+            c.append(n)
+        return (c, )
+
 class VAEDecode:
     @classmethod
     def INPUT_TYPES(s):
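Note the inversion in `set_range`: the node's `start`/`end` inputs count up from 0.0 (first sampling step) to 1.0 (last step), while the stored `start_percent`/`end_percent` follow the sampler's schedule, which runs down from 1.0 to 0.0. A toy illustration of the mapping (the helper function is mine, not from the commit):

```python
# Illustrative mapping from the node's (start, end) inputs to the
# internal (start_percent, end_percent) values stored above.
def to_internal_range(start, end):  # hypothetical helper
    return (1.0 - start, 1.0 - end)

assert to_internal_range(0.0, 1.0) == (1.0, 0.0)  # defaults: whole schedule
assert to_internal_range(0.0, 0.5) == (1.0, 0.5)  # first half of sampling only
```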
@ -330,10 +354,12 @@ class SaveLatent:
         if prompt is not None:
             prompt_info = json.dumps(prompt)

-        metadata = {"prompt": prompt_info}
-        if extra_pnginfo is not None:
-            for x in extra_pnginfo:
-                metadata[x] = json.dumps(extra_pnginfo[x])
+        metadata = None
+        if not args.disable_metadata:
+            metadata = {"prompt": prompt_info}
+            if extra_pnginfo is not None:
+                for x in extra_pnginfo:
+                    metadata[x] = json.dumps(extra_pnginfo[x])

         file = f"{filename}_{counter:05}_.latent"
         file = os.path.join(full_output_folder, file)
@ -580,9 +606,58 @@ class ControlNetApply:
             if 'control' in t[1]:
                 c_net.set_previous_controlnet(t[1]['control'])
             n[1]['control'] = c_net
+            n[1]['control_apply_to_uncond'] = True
             c.append(n)
         return (c, )

+class ControlNetApplyAdvanced:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "control_net": ("CONTROL_NET", ),
+                             "image": ("IMAGE", ),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                             "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                             "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
+                             }}

+    RETURN_TYPES = ("CONDITIONING","CONDITIONING")
+    RETURN_NAMES = ("positive", "negative")
+    FUNCTION = "apply_controlnet"
+
+    CATEGORY = "conditioning"
+
+    def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent):
+        if strength == 0:
+            return (positive, negative)
+
+        control_hint = image.movedim(-1,1)
+        cnets = {}
+
+        out = []
+        for conditioning in [positive, negative]:
+            c = []
+            for t in conditioning:
+                d = t[1].copy()
+
+                prev_cnet = d.get('control', None)
+                if prev_cnet in cnets:
+                    c_net = cnets[prev_cnet]
+                else:
+                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
+                    c_net.set_previous_controlnet(prev_cnet)
+                    cnets[prev_cnet] = c_net
+
+                d['control'] = c_net
+                d['control_apply_to_uncond'] = False
+                n = [t[0], d]
+                c.append(n)
+            out.append(c)
+        return (out[0], out[1])
+
+
 class UNETLoader:
     @classmethod
     def INPUT_TYPES(s):
@ -1143,12 +1218,14 @@ class SaveImage:
         for image in images:
             i = 255. * image.cpu().numpy()
             img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
-            metadata = PngInfo()
-            if prompt is not None:
-                metadata.add_text("prompt", json.dumps(prompt))
-            if extra_pnginfo is not None:
-                for x in extra_pnginfo:
-                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))
+            metadata = None
+            if not args.disable_metadata:
+                metadata = PngInfo()
+                if prompt is not None:
+                    metadata.add_text("prompt", json.dumps(prompt))
+                if extra_pnginfo is not None:
+                    for x in extra_pnginfo:
+                        metadata.add_text(x, json.dumps(extra_pnginfo[x]))

             file = f"{filename}_{counter:05}_.png"
             img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
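When metadata saving stays enabled (no `--disable-metadata`), the prompt that `SaveImage` embeds can be read back from the PNG's text chunks. A minimal sketch; the filename is a placeholder:

```python
# Sketch: recover the workflow prompt from a saved image. PIL exposes
# PngInfo text chunks through the .text attribute of an opened PNG.
import json
from PIL import Image

img = Image.open("ComfyUI_00001_.png")   # placeholder output filename
prompt = json.loads(img.text["prompt"])  # the "prompt" chunk written above
```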
@ -1427,6 +1504,7 @@ NODE_CLASS_MAPPINGS = {
     "StyleModelApply": StyleModelApply,
     "unCLIPConditioning": unCLIPConditioning,
     "ControlNetApply": ControlNetApply,
+    "ControlNetApplyAdvanced": ControlNetApplyAdvanced,
     "ControlNetLoader": ControlNetLoader,
     "DiffControlNetLoader": DiffControlNetLoader,
     "StyleModelLoader": StyleModelLoader,

@ -1444,6 +1522,7 @@ NODE_CLASS_MAPPINGS = {
     "SaveLatent": SaveLatent,

     "ConditioningZeroOut": ConditioningZeroOut,
+    "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
 }

 NODE_DISPLAY_NAME_MAPPINGS = {

@ -1472,6 +1551,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ConditioningSetArea": "Conditioning (Set Area)",
     "ConditioningSetMask": "Conditioning (Set Mask)",
     "ControlNetApply": "Apply ControlNet",
+    "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
     # Latent
     "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
     "SetLatentNoiseMask": "Set Latent Noise Mask",
@ -69,6 +69,13 @@
 "source": [
 "# Checkpoints\n",
 "\n",
+"### SDXL\n",
+"### I recommend these workflow examples: https://comfyanonymous.github.io/ComfyUI_examples/sdxl/\n",
+"\n",
+"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n",
+"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n",
+"\n",
+"\n",
 "# SD1.5\n",
 "!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
 "\n",

@ -83,7 +90,7 @@
 "#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n",
 "\n",
 "# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n",
-"#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta2/resolve/main/checkpoints/wd-1-5-beta2-fp16.safetensors -P ./models/checkpoints/\n",
+"#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta3/resolve/main/wd-illusion-fp16.safetensors -P ./models/checkpoints/\n",
 "\n",
 "\n",
 "# unCLIP models\n",

@ -100,6 +107,7 @@
 "# Loras\n",
 "#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n",
 "#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n",
+"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors -P ./models/loras/ #SDXL offset noise lora\n",
 "\n",
 "\n",
 "# T2I-Adapter\n",
@ -1,5 +1,4 @@
 torch
-torchdiffeq
 torchsde
 einops
 transformers>=4.25.1
@ -2,8 +2,12 @@ import json
 from urllib import request, parse
 import random

-#this is the ComfyUI api prompt format. If you want it for a specific workflow you can copy it from the prompt section
-#of the image metadata of images generated with ComfyUI
+#This is the ComfyUI api prompt format.
+
+#If you want it for a specific workflow you can "enable dev mode options"
+#in the settings of the UI (gear beside the "Queue Size: ") this will enable
+#a button on the UI to save workflows in api format.

 #keep in mind ComfyUI is pre alpha software so this format will change a bit.

 #this is the one for the default workflow
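For context, submitting such an api-format prompt boils down to a single HTTP POST. A minimal sketch, assuming the server is running on the default port 8188 from main.py and exposes the `/prompt` route:

```python
# Sketch: queue an api-format prompt over HTTP; the address and route
# here are assumptions based on the default server configuration.
import json
from urllib import request

def queue_prompt(prompt):
    data = json.dumps({"prompt": prompt}).encode('utf-8')
    req = request.Request("http://127.0.0.1:8188/prompt", data=data)
    request.urlopen(req)
```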
server.py
@ -8,7 +8,7 @@ import uuid
 import json
 import glob
 import struct
-from PIL import Image
+from PIL import Image, ImageOps
 from io import BytesIO

 try:

@ -29,6 +29,7 @@ import comfy.model_management

 class BinaryEventTypes:
     PREVIEW_IMAGE = 1
+    UNENCODED_PREVIEW_IMAGE = 2

 async def send_socket_catch_exception(function, message):
     try:

@ -498,7 +499,9 @@ class PromptServer():
         return prompt_info

     async def send(self, event, data, sid=None):
-        if isinstance(data, (bytes, bytearray)):
+        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
+            await self.send_image(data, sid=sid)
+        elif isinstance(data, (bytes, bytearray)):
             await self.send_bytes(event, data, sid)
         else:
             await self.send_json(event, data, sid)

@ -512,6 +515,30 @@ class PromptServer():
             message.extend(data)
         return message

+    async def send_image(self, image_data, sid=None):
+        image_type = image_data[0]
+        image = image_data[1]
+        max_size = image_data[2]
+        if max_size is not None:
+            if hasattr(Image, 'Resampling'):
+                resampling = Image.Resampling.BILINEAR
+            else:
+                resampling = Image.ANTIALIAS
+
+            image = ImageOps.contain(image, (max_size, max_size), resampling)
+        type_num = 1
+        if image_type == "JPEG":
+            type_num = 1
+        elif image_type == "PNG":
+            type_num = 2
+
+        bytesIO = BytesIO()
+        header = struct.pack(">I", type_num)
+        bytesIO.write(header)
+        image.save(bytesIO, format=image_type, quality=95, compress_level=4)
+        preview_bytes = bytesIO.getvalue()
+        await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
+
     async def send_bytes(self, event, data, sid=None):
         message = self.encode_bytes(event, data)
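Reading `encode_bytes` and `send_image` together, a preview frame on the websocket is laid out as: a 4-byte big-endian event id (1 = `PREVIEW_IMAGE`), a 4-byte big-endian image type (1 = JPEG, 2 = PNG), then the encoded image itself. A hedged client-side decoder sketch built on that layout:

```python
# Sketch: decode a binary preview frame received over the websocket.
import struct
from io import BytesIO
from PIL import Image

def decode_preview_frame(frame: bytes) -> Image.Image:
    event, image_type = struct.unpack(">II", frame[:8])
    assert event == 1            # BinaryEventTypes.PREVIEW_IMAGE
    assert image_type in (1, 2)  # 1 = JPEG, 2 = PNG
    return Image.open(BytesIO(frame[8:]))
```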
web/types/comfy.d.ts (vendored)
@ -30,9 +30,7 @@ export interface ComfyExtension {
 	getCustomWidgets(
 		app: ComfyApp
 	): Promise<
-		Array<
-			Record<string, (node, inputName, inputData, app) => { widget?: IWidget; minWidth?: number; minHeight?: number }>
-		>
+		Record<string, (node, inputName, inputData, app) => { widget?: IWidget; minWidth?: number; minHeight?: number }>
 	>;
 	/**
 	 * Allows the extension to add additional handling to the node before it is registered with LGraph