Merge branch 'comfyanonymous:master' into feature/maskpainting

ltdrdata, 2023-05-02 11:28:48 +09:00 (committed by GitHub)
commit 8c8bedbaf2
7 changed files with 170 additions and 22 deletions

comfy/extra_samplers/uni_pc.py

@@ -712,7 +712,7 @@ class UniPC:
     def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
         method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
-        atol=0.0078, rtol=0.05, corrector=False, callback=None
+        atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
     ):
         t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
         t_T = self.noise_schedule.T if t_start is None else t_start
@@ -723,7 +723,7 @@ class UniPC:
         # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
         assert timesteps.shape[0] - 1 == steps
         # with torch.no_grad():
-        for step_index in trange(steps):
+        for step_index in trange(steps, disable=disable_pbar):
             if self.noise_mask is not None:
                 x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
             if step_index == 0:
@@ -835,7 +835,7 @@ def expand_dims(v, dims):
-def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=None, noise_mask=None, variant='bh1'):
+def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
     to_zero = False
     if sigmas[-1] == 0:
         timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
@@ -879,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
     order = min(3, len(timesteps) - 1)
     uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
-    x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback)
+    x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
     if not to_zero:
         x /= ns.marginal_alpha(timesteps[-1])
     return x
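
Note: disable_pbar is forwarded verbatim to tqdm's disable argument, so setting it to True suppresses the per-step progress bar without changing the sampling loop itself. A minimal standalone sketch of the mechanism (illustrative names, not ComfyUI code):

    from tqdm.auto import trange

    def run_steps(steps, disable_pbar=False):
        # tqdm emits no terminal output at all when disable=True
        for step_index in trange(steps, disable=disable_pbar):
            pass  # one solver step per iteration

    run_steps(10)                     # prints a progress bar
    run_steps(10, disable_pbar=True)  # silent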

comfy/ldm/models/diffusion/ddim.py

@@ -81,6 +81,7 @@ class DDIMSampler(object):
                extra_args=None,
                to_zero=True,
                end_step=None,
+               disable_pbar=False,
                **kwargs
                ):
         self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
@@ -103,7 +104,8 @@ class DDIMSampler(object):
                                     denoise_function=denoise_function,
                                     extra_args=extra_args,
                                     to_zero=to_zero,
-                                    end_step=end_step
+                                    end_step=end_step,
+                                    disable_pbar=disable_pbar
                                     )
         return samples, intermediates
@@ -185,7 +187,7 @@ class DDIMSampler(object):
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
-                      ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None):
+                      ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
@@ -204,7 +206,7 @@ class DDIMSampler(object):
         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
         # print(f"Running DDIM Sampling with {total_steps} timesteps")
-        iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step)
+        iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
         for i, step in enumerate(iterator):
             index = total_steps - i - 1

comfy/sample.py

@@ -56,7 +56,7 @@ def cleanup_additional_models(models):
     for m in models:
         m.cleanup()

-def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None):
+def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False):
     device = comfy.model_management.get_torch_device()

     if noise_mask is not None:
@@ -76,7 +76,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)

-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback)
+    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
     samples = samples.cpu()

     cleanup_additional_models(models)

comfy/samplers.py

@@ -541,7 +541,7 @@ class KSampler:
             sigmas = self.calculate_sigmas(new_steps).to(self.device)
             self.sigmas = sigmas[-(steps + 1):]

-    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None):
+    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False):
         if sigmas is None:
             sigmas = self.sigmas
         sigma_min = self.sigma_min
@@ -610,9 +610,9 @@ class KSampler:
         with precision_scope(model_management.get_autocast_device(self.device)):
             if self.sampler == "uni_pc":
-                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback)
+                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
             elif self.sampler == "uni_pc_bh2":
-                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2')
+                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
             elif self.sampler == "ddim":
                 timesteps = []
                 for s in range(sigmas.shape[0]):
@@ -643,7 +643,8 @@ class KSampler:
                     extra_args=extra_args,
                     mask=noise_mask,
                     to_zero=sigmas[-1]==0,
-                    end_step=sigmas.shape[0] - 1)
+                    end_step=sigmas.shape[0] - 1,
+                    disable_pbar=disable_pbar)
             else:
                 extra_args["denoise_mask"] = denoise_mask
@@ -659,10 +660,10 @@ class KSampler:
                 if latent_image is not None:
                     noise += latent_image
                 if self.sampler == "dpm_fast":
-                    samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback)
+                    samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
                 elif self.sampler == "dpm_adaptive":
-                    samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback)
+                    samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
                 else:
-                    samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback)
+                    samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)

         return samples.to(torch.float32)
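
With this hunk, one public flag reaches every backend: uni_pc gets disable=, DDIM gets disable_pbar=, and the k-diffusion samplers, including the getattr-dispatched ones, get disable=. A standalone sketch of that dispatch pattern (module and function bodies are stand-ins, not the real k-diffusion API):

    import types
    from tqdm.auto import trange

    # Stand-in for the k_diffusion_sampling module used above.
    k_diffusion_sampling = types.SimpleNamespace()

    def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None):
        # Like the real k-diffusion samplers, accept `disable` and forward it to trange.
        for _ in trange(len(sigmas) - 1, disable=disable):
            pass
        return x

    k_diffusion_sampling.sample_euler = sample_euler

    sampler_name = "euler"
    fn = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
    fn(None, 0.0, [1.0, 0.5, 0.0], disable=True)  # resolves sample_euler, runs silently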

comfy/sd.py

@@ -111,6 +111,8 @@ def load_lora(path, to_load):
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)

+        ######## loha
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
@@ -132,6 +134,54 @@ def load_lora(path, to_load):
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)

+        ######## lokr
+        lokr_w1_name = "{}.lokr_w1".format(x)
+        lokr_w2_name = "{}.lokr_w2".format(x)
+        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
+        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
+        lokr_t2_name = "{}.lokr_t2".format(x)
+        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
+        lokr_w2_b_name = "{}.lokr_w2_b".format(x)
+
+        lokr_w1 = None
+        if lokr_w1_name in lora.keys():
+            lokr_w1 = lora[lokr_w1_name]
+            loaded_keys.add(lokr_w1_name)
+
+        lokr_w2 = None
+        if lokr_w2_name in lora.keys():
+            lokr_w2 = lora[lokr_w2_name]
+            loaded_keys.add(lokr_w2_name)
+
+        lokr_w1_a = None
+        if lokr_w1_a_name in lora.keys():
+            lokr_w1_a = lora[lokr_w1_a_name]
+            loaded_keys.add(lokr_w1_a_name)
+
+        lokr_w1_b = None
+        if lokr_w1_b_name in lora.keys():
+            lokr_w1_b = lora[lokr_w1_b_name]
+            loaded_keys.add(lokr_w1_b_name)
+
+        lokr_w2_a = None
+        if lokr_w2_a_name in lora.keys():
+            lokr_w2_a = lora[lokr_w2_a_name]
+            loaded_keys.add(lokr_w2_a_name)
+
+        lokr_w2_b = None
+        if lokr_w2_b_name in lora.keys():
+            lokr_w2_b = lora[lokr_w2_b_name]
+            loaded_keys.add(lokr_w2_b_name)
+
+        lokr_t2 = None
+        if lokr_t2_name in lora.keys():
+            lokr_t2 = lora[lokr_t2_name]
+            loaded_keys.add(lokr_t2_name)
+
+        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
+            patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
+
    for x in lora.keys():
        if x not in loaded_keys:
            print("lora key not loaded", x)
@@ -315,6 +365,33 @@ class ModelPatcher:
                        final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
                        mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                elif len(v) == 8: #lokr
+                    w1 = v[0]
+                    w2 = v[1]
+                    w1_a = v[3]
+                    w1_b = v[4]
+                    w2_a = v[5]
+                    w2_b = v[6]
+                    t2 = v[7]
+                    dim = None
+
+                    if w1 is None:
+                        dim = w1_b.shape[0]
+                        w1 = torch.mm(w1_a.float(), w1_b.float())
+
+                    if w2 is None:
+                        dim = w2_b.shape[0]
+                        if t2 is None:
+                            w2 = torch.mm(w2_a.float(), w2_b.float())
+                        else:
+                            w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
+
+                    if len(w2.shape) == 4:
+                        w1 = w1.unsqueeze(2).unsqueeze(2)
+                    if v[2] is not None and dim is not None:
+                        alpha *= v[2] / dim
+
+                    weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
                else: #loha
                    w1a = v[0]
                    w1b = v[1]
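
The lokr branch rebuilds any factored side first (w1 = w1_a @ w1_b, and w2 via einsum when a Tucker core t2 is present), scales alpha by v[2]/dim when both are available, and applies the update as a Kronecker product, delta_W = alpha * kron(w1, w2). A worked sketch of the shape arithmetic (all values illustrative):

    import torch

    # Kronecker block sizes multiply: kron of (2, 3) with (4, 2) gives (2*4, 3*2) = (8, 6).
    w1 = torch.randn(2, 3)
    w2_a = torch.randn(4, 2)
    w2_b = torch.randn(2, 2)
    w2 = torch.mm(w2_a.float(), w2_b.float())   # low-rank reconstruction, rank 2 -> (4, 2)

    weight = torch.zeros(8, 6)                  # stand-in for the model weight being patched
    alpha = 0.5
    delta = torch.kron(w1.float(), w2.float())  # shape (8, 6)
    weight += alpha * delta.reshape(weight.shape).type(weight.dtype)
    print(weight.shape)  # torch.Size([8, 6])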

web/extensions/core/colorPalette.js

@@ -232,10 +232,27 @@ app.registerExtension({
                    "name": "My Color Palette",
                    "colors": {
                        "node_slot": {
+                        },
+                        "litegraph_base": {
+                        },
+                        "comfy_base": {
                        }
                    }
                };

+                // Copy over missing keys from default color palette
+                const defaultColorPalette = colorPalettes[defaultColorPaletteId];
+                for (const key in defaultColorPalette.colors.litegraph_base) {
+                    if (!colorPalette.colors.litegraph_base[key]) {
+                        colorPalette.colors.litegraph_base[key] = "";
+                    }
+                }
+                for (const key in defaultColorPalette.colors.comfy_base) {
+                    if (!colorPalette.colors.comfy_base[key]) {
+                        colorPalette.colors.comfy_base[key] = "";
+                    }
+                }
+
                return completeColorPalette(colorPalette);
            };
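
The new loops backfill a user palette with every litegraph_base and comfy_base key from the default palette, using "" as the placeholder, so later lookups never hit undefined. The same idea as a standalone Python sketch (palette contents hypothetical; note the JS version also replaces falsy values, whereas setdefault only fills keys that are absent):

    default_palette = {"litegraph_base": {"NODE_TITLE_COLOR": "#999", "LINK_COLOR": "#9A9"},
                       "comfy_base": {"fg-color": "#fff", "bg-color": "#202020"}}
    user_palette = {"litegraph_base": {"NODE_TITLE_COLOR": "#fff"}}

    for group, defaults in default_palette.items():
        target = user_palette.setdefault(group, {})
        for key in defaults:
            target.setdefault(key, "")  # absent keys become "" placeholders

    print(user_palette)  # LINK_COLOR and all comfy_base keys now present (as "")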

web/style.css

@@ -257,8 +257,11 @@ button.comfy-queue-btn {
 	}
 }

+/* Input popup */
 .graphdialog {
 	min-height: 1em;
+	background-color: var(--comfy-menu-bg);
 }

 .graphdialog .name {
@@ -282,18 +285,66 @@ button.comfy-queue-btn {
 	border-radius: 12px 0 0 12px;
 }

+/* Context menu */
 .litegraph .litemenu-entry.has_submenu {
 	position: relative;
 	padding-right: 20px;
 }

 .litemenu-entry.has_submenu::after {
 	content: ">";
 	position: absolute;
 	top: 0;
 	right: 2px;
 }

-.litecontextmenu {
-	z-index: 9999 !important;
+.litegraph.litecontextmenu,
+.litegraph.litecontextmenu.dark {
+	z-index: 9999 !important;
+	background-color: var(--comfy-menu-bg) !important;
+	filter: brightness(95%);
+}
+
+.litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) {
+	background-color: var(--comfy-menu-bg) !important;
+	filter: brightness(155%);
+	color: var(--input-text);
+}
+
+.litegraph.litecontextmenu .litemenu-entry.submenu,
+.litegraph.litecontextmenu.dark .litemenu-entry.submenu {
+	background-color: var(--comfy-menu-bg) !important;
+	color: var(--input-text);
+}
+
+.litegraph.litecontextmenu input {
+	background-color: var(--comfy-input-bg) !important;
+	color: var(--input-text) !important;
+}
+
+/* Search box */
+.litegraph.litesearchbox {
+	z-index: 9999 !important;
+	background-color: var(--comfy-menu-bg) !important;
+	overflow: hidden;
+}
+
+.litegraph.litesearchbox input,
+.litegraph.litesearchbox select {
+	background-color: var(--comfy-input-bg) !important;
+	color: var(--input-text);
+}
+
+.litegraph.lite-search-item {
+	color: var(--input-text);
+	background-color: var(--comfy-input-bg);
+	filter: brightness(80%);
+	padding-left: 0.2em;
+}
+
+.litegraph.lite-search-item.generic_type {
+	color: var(--input-text);
+	filter: brightness(50%);
 }