mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 13:32:36 +08:00
Merge branch 'Main' into feature/preview-latent
This commit is contained in:
commit
70070ddb5b
@ -8,7 +8,8 @@ import os.path as osp
|
||||
import re
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
import diffusers_convert
|
||||
from . import diffusers_convert
|
||||
|
||||
|
||||
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
|
||||
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
|
||||
|
||||
@ -215,10 +215,12 @@ class PositionNet(nn.Module):
|
||||
|
||||
def forward(self, boxes, masks, positive_embeddings):
|
||||
B, N, _ = boxes.shape
|
||||
masks = masks.unsqueeze(-1)
|
||||
dtype = self.linears[0].weight.dtype
|
||||
masks = masks.unsqueeze(-1).to(dtype)
|
||||
positive_embeddings = positive_embeddings.to(dtype)
|
||||
|
||||
# embedding position (it may includes padding as placeholder)
|
||||
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
|
||||
xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C
|
||||
|
||||
# learnable null embedding
|
||||
positive_null = self.null_positive_feature.view(1, 1, -1)
|
||||
@ -252,7 +254,8 @@ class Gligen(nn.Module):
|
||||
|
||||
if self.lowvram == True:
|
||||
self.position_net.cpu()
|
||||
def func_lowvram(key, x):
|
||||
def func_lowvram(x, extra_options):
|
||||
key = extra_options["transformer_index"]
|
||||
module = self.module_list[key]
|
||||
module.to(x.device)
|
||||
r = module(x, objs)
|
||||
|
||||
@ -278,7 +278,7 @@ class CrossAttentionDoggettx(nn.Module):
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
@ -314,7 +314,7 @@ class CrossAttentionDoggettx(nn.Module):
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
first_op_done = True
|
||||
|
||||
s2 = s1.softmax(dim=-1)
|
||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
|
||||
@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels, dtype=dtype),
|
||||
nn.GroupNorm(32, channels, dtype=dtype),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
|
||||
)
|
||||
@ -244,7 +244,7 @@ class ResBlock(TimestepBlock):
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels, dtype=dtype),
|
||||
nn.GroupNorm(32, self.out_channels, dtype=dtype),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
@ -778,13 +778,13 @@ class UNetModel(nn.Module):
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch, dtype=self.dtype),
|
||||
nn.GroupNorm(32, ch, dtype=self.dtype),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.GroupNorm(32, ch, dtype=self.dtype),
|
||||
conv_nd(dims, model_channels, n_embed, 1),
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
@ -821,7 +821,7 @@ class UNetModel(nn.Module):
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
|
||||
@ -84,7 +84,7 @@ def _summarize_chunk(
|
||||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||
max_score = max_score.detach()
|
||||
torch.exp(attn_weights - max_score, out=attn_weights)
|
||||
exp_weights = attn_weights
|
||||
exp_weights = attn_weights.to(value.dtype)
|
||||
exp_values = torch.bmm(exp_weights, value)
|
||||
max_score = max_score.squeeze(-1)
|
||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
||||
@ -166,7 +166,7 @@ def _get_attention_scores_no_kv_chunking(
|
||||
attn_scores /= summed
|
||||
attn_probs = attn_scores
|
||||
|
||||
hidden_states_slice = torch.bmm(attn_probs, value)
|
||||
hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
|
||||
return hidden_states_slice
|
||||
|
||||
class ScannedChunk(NamedTuple):
|
||||
|
||||
@ -52,7 +52,13 @@ class BaseModel(torch.nn.Module):
|
||||
else:
|
||||
xc = x
|
||||
context = torch.cat(c_crossattn, 1)
|
||||
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)
|
||||
dtype = self.get_dtype()
|
||||
xc = xc.to(dtype)
|
||||
t = t.to(dtype)
|
||||
context = context.to(dtype)
|
||||
if c_adm is not None:
|
||||
c_adm = c_adm.to(dtype)
|
||||
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
||||
@ -108,11 +108,13 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
|
||||
unet_config["context_dim"] = context_dim
|
||||
return unet_config
|
||||
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
|
||||
def model_config_from_unet_config(unet_config):
|
||||
for model_config in supported_models.models:
|
||||
if model_config.matches(unet_config):
|
||||
return model_config(unet_config)
|
||||
|
||||
return None
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
|
||||
return model_config_from_unet_config(unet_config)
|
||||
|
||||
@ -264,6 +264,7 @@ def load_model_gpu(model):
|
||||
|
||||
torch_dev = model.load_device
|
||||
model.model_patches_to(torch_dev)
|
||||
model.model_patches_to(model.model_dtype())
|
||||
|
||||
if is_device_cpu(torch_dev):
|
||||
vram_set_state = VRAMState.DISABLED
|
||||
|
||||
@ -51,11 +51,11 @@ def get_models_from_cond(cond, model_type):
|
||||
models += [c[1][model_type]]
|
||||
return models
|
||||
|
||||
def load_additional_models(positive, negative):
|
||||
def load_additional_models(positive, negative, dtype):
|
||||
"""loads additional models in positive and negative conditioning"""
|
||||
control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
|
||||
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
|
||||
gligen = [x[1] for x in gligen]
|
||||
gligen = [x[1].to(dtype) for x in gligen]
|
||||
models = control_nets + gligen
|
||||
comfy.model_management.load_controlnet_gpu(models)
|
||||
return models
|
||||
@ -81,7 +81,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
positive_copy = broadcast_cond(positive, noise.shape[0], device)
|
||||
negative_copy = broadcast_cond(negative, noise.shape[0], device)
|
||||
|
||||
models = load_additional_models(positive, negative)
|
||||
models = load_additional_models(positive, negative, model.model_dtype())
|
||||
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@ from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .k_diffusion import external as k_diffusion_external
|
||||
from .extra_samplers import uni_pc
|
||||
import torch
|
||||
import contextlib
|
||||
from comfy import model_management
|
||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||
@ -577,11 +576,6 @@ class KSampler:
|
||||
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
if self.model.get_dtype() == torch.float16:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = contextlib.nullcontext
|
||||
|
||||
if self.model.is_adm():
|
||||
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
|
||||
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
|
||||
@ -612,67 +606,67 @@ class KSampler:
|
||||
else:
|
||||
max_denoise = True
|
||||
|
||||
with precision_scope(model_management.get_autocast_device(self.device)):
|
||||
if self.sampler == "uni_pc":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||
elif self.sampler == "uni_pc_bh2":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||
elif self.sampler == "ddim":
|
||||
timesteps = []
|
||||
for s in range(sigmas.shape[0]):
|
||||
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
|
||||
noise_mask = None
|
||||
if denoise_mask is not None:
|
||||
noise_mask = 1.0 - denoise_mask
|
||||
|
||||
ddim_callback = None
|
||||
if callback is not None:
|
||||
total_steps = len(timesteps) - 1
|
||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
||||
if self.sampler == "uni_pc":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||
elif self.sampler == "uni_pc_bh2":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||
elif self.sampler == "ddim":
|
||||
timesteps = []
|
||||
for s in range(sigmas.shape[0]):
|
||||
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
|
||||
noise_mask = None
|
||||
if denoise_mask is not None:
|
||||
noise_mask = 1.0 - denoise_mask
|
||||
|
||||
sampler = DDIMSampler(self.model, device=self.device)
|
||||
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
|
||||
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
|
||||
conditioning=positive,
|
||||
batch_size=noise.shape[0],
|
||||
shape=noise.shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=cfg,
|
||||
unconditional_conditioning=negative,
|
||||
eta=0.0,
|
||||
x_T=z_enc,
|
||||
x0=latent_image,
|
||||
img_callback=ddim_callback,
|
||||
denoise_function=sampling_function,
|
||||
extra_args=extra_args,
|
||||
mask=noise_mask,
|
||||
to_zero=sigmas[-1]==0,
|
||||
end_step=sigmas.shape[0] - 1,
|
||||
disable_pbar=disable_pbar)
|
||||
ddim_callback = None
|
||||
if callback is not None:
|
||||
total_steps = len(timesteps) - 1
|
||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
||||
|
||||
sampler = DDIMSampler(self.model, device=self.device)
|
||||
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
|
||||
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
|
||||
conditioning=positive,
|
||||
batch_size=noise.shape[0],
|
||||
shape=noise.shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=cfg,
|
||||
unconditional_conditioning=negative,
|
||||
eta=0.0,
|
||||
x_T=z_enc,
|
||||
x0=latent_image,
|
||||
img_callback=ddim_callback,
|
||||
denoise_function=sampling_function,
|
||||
extra_args=extra_args,
|
||||
mask=noise_mask,
|
||||
to_zero=sigmas[-1]==0,
|
||||
end_step=sigmas.shape[0] - 1,
|
||||
disable_pbar=disable_pbar)
|
||||
|
||||
else:
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
self.model_k.latent_image = latent_image
|
||||
self.model_k.noise = noise
|
||||
|
||||
if max_denoise:
|
||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
else:
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
self.model_k.latent_image = latent_image
|
||||
self.model_k.noise = noise
|
||||
noise = noise * sigmas[0]
|
||||
|
||||
if max_denoise:
|
||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
else:
|
||||
noise = noise * sigmas[0]
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
if self.sampler == "dpm_fast":
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
elif self.sampler == "dpm_adaptive":
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
else:
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
if self.sampler == "dpm_fast":
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
elif self.sampler == "dpm_adaptive":
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
else:
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
|
||||
return self.model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
72
comfy/sd.py
72
comfy/sd.py
@ -291,7 +291,8 @@ class ModelPatcher:
|
||||
patch_list[k] = patch_list[k].to(device)
|
||||
|
||||
def model_dtype(self):
|
||||
return self.model.get_dtype()
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
return self.model.get_dtype()
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
p = {}
|
||||
@ -1049,7 +1050,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model_config.get_model(sd)
|
||||
model = model_config.get_model(sd, "model.diffusion_model.")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(sd, "model.diffusion_model.")
|
||||
|
||||
@ -1073,6 +1074,73 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
|
||||
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_unet(unet_path): #load unet in diffusers format
|
||||
sd = utils.load_torch_file(unet_path)
|
||||
parameters = calculate_parameters(sd, "")
|
||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
||||
|
||||
match = {}
|
||||
match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
|
||||
match["model_channels"] = sd["conv_in.weight"].shape[0]
|
||||
match["in_channels"] = sd["conv_in.weight"].shape[1]
|
||||
match["adm_in_channels"] = None
|
||||
if "class_embedding.linear_1.weight" in sd:
|
||||
match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
|
||||
|
||||
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
|
||||
|
||||
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||
|
||||
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||
|
||||
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||
|
||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
|
||||
print("match", match)
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
for k in match:
|
||||
if match[k] != unet_config[k]:
|
||||
matches = False
|
||||
break
|
||||
if matches:
|
||||
diffusers_keys = utils.unet_to_diffusers(unet_config)
|
||||
new_sd = {}
|
||||
for k in diffusers_keys:
|
||||
if k in sd:
|
||||
new_sd[diffusers_keys[k]] = sd.pop(k)
|
||||
else:
|
||||
print(diffusers_keys[k], k)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model_config = model_detection.model_config_from_unet_config(unet_config)
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
||||
|
||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
||||
try:
|
||||
model.patch_model()
|
||||
|
||||
@ -53,9 +53,9 @@ class SD20(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
|
||||
def v_prediction(self, state_dict):
|
||||
def v_prediction(self, state_dict, prefix=""):
|
||||
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
||||
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
|
||||
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
|
||||
out = state_dict[k]
|
||||
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
||||
return True
|
||||
@ -109,7 +109,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
def get_model(self, state_dict):
|
||||
def get_model(self, state_dict, prefix=""):
|
||||
return model_base.SDXLRefiner(self)
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
@ -144,7 +144,7 @@ class SDXL(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
def get_model(self, state_dict):
|
||||
def get_model(self, state_dict, prefix=""):
|
||||
return model_base.SDXL(self)
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
|
||||
@ -41,7 +41,7 @@ class BASE:
|
||||
return False
|
||||
return True
|
||||
|
||||
def v_prediction(self, state_dict):
|
||||
def v_prediction(self, state_dict, prefix=""):
|
||||
return False
|
||||
|
||||
def inpaint_model(self):
|
||||
@ -53,13 +53,13 @@ class BASE:
|
||||
for x in self.unet_extra_config:
|
||||
self.unet_config[x] = self.unet_extra_config[x]
|
||||
|
||||
def get_model(self, state_dict):
|
||||
def get_model(self, state_dict, prefix=""):
|
||||
if self.inpaint_model():
|
||||
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
|
||||
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
|
||||
elif self.noise_aug_config is not None:
|
||||
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
|
||||
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
|
||||
else:
|
||||
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))
|
||||
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
@ -117,14 +117,33 @@ UNET_MAP_RESNET = {
|
||||
"out_layers.0.bias": "norm2.bias",
|
||||
}
|
||||
|
||||
UNET_MAP_BASIC = {
|
||||
"label_emb.0.0.weight": "class_embedding.linear_1.weight",
|
||||
"label_emb.0.0.bias": "class_embedding.linear_1.bias",
|
||||
"label_emb.0.2.weight": "class_embedding.linear_2.weight",
|
||||
"label_emb.0.2.bias": "class_embedding.linear_2.bias",
|
||||
"input_blocks.0.0.weight": "conv_in.weight",
|
||||
"input_blocks.0.0.bias": "conv_in.bias",
|
||||
"out.0.weight": "conv_norm_out.weight",
|
||||
"out.0.bias": "conv_norm_out.bias",
|
||||
"out.2.weight": "conv_out.weight",
|
||||
"out.2.bias": "conv_out.bias",
|
||||
"time_embed.0.weight": "time_embedding.linear_1.weight",
|
||||
"time_embed.0.bias": "time_embedding.linear_1.bias",
|
||||
"time_embed.2.weight": "time_embedding.linear_2.weight",
|
||||
"time_embed.2.bias": "time_embedding.linear_2.bias"
|
||||
}
|
||||
|
||||
def unet_to_diffusers(unet_config):
|
||||
num_res_blocks = unet_config["num_res_blocks"]
|
||||
attention_resolutions = unet_config["attention_resolutions"]
|
||||
channel_mult = unet_config["channel_mult"]
|
||||
transformer_depth = unet_config["transformer_depth"]
|
||||
num_blocks = len(channel_mult)
|
||||
if not isinstance(num_res_blocks, list):
|
||||
if isinstance(num_res_blocks, int):
|
||||
num_res_blocks = [num_res_blocks] * num_blocks
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = [transformer_depth] * num_blocks
|
||||
|
||||
transformers_per_layer = []
|
||||
res = 1
|
||||
@ -135,7 +154,7 @@ def unet_to_diffusers(unet_config):
|
||||
transformers_per_layer.append(transformers)
|
||||
res *= 2
|
||||
|
||||
transformers_mid = unet_config.get("transformer_depth_middle", transformers_per_layer[-1])
|
||||
transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])
|
||||
|
||||
diffusers_unet_map = {}
|
||||
for x in range(num_blocks):
|
||||
@ -185,6 +204,10 @@ def unet_to_diffusers(unet_config):
|
||||
for k in ["weight", "bias"]:
|
||||
diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
|
||||
n += 1
|
||||
|
||||
for k in UNET_MAP_BASIC:
|
||||
diffusers_unet_map[UNET_MAP_BASIC[k]] = k
|
||||
|
||||
return diffusers_unet_map
|
||||
|
||||
def convert_sd_to(state_dict, dtype):
|
||||
|
||||
@ -14,6 +14,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
|
||||
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
|
||||
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
|
||||
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
|
||||
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
|
||||
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
|
||||
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
||||
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
|
||||
|
||||
0
models/unet/put_unet_files_here
Normal file
0
models/unet/put_unet_files_here
Normal file
48
nodes.py
48
nodes.py
@ -105,6 +105,34 @@ class ConditioningAverage :
|
||||
out.append(n)
|
||||
return (out, )
|
||||
|
||||
class ConditioningConcat:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"conditioning_to": ("CONDITIONING",),
|
||||
"conditioning_from": ("CONDITIONING",),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "concat"
|
||||
|
||||
CATEGORY = "advanced/conditioning"
|
||||
|
||||
def concat(self, conditioning_to, conditioning_from):
|
||||
out = []
|
||||
|
||||
if len(conditioning_from) > 1:
|
||||
print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
|
||||
|
||||
cond_from = conditioning_from[0][0]
|
||||
|
||||
for i in range(len(conditioning_to)):
|
||||
t1 = conditioning_to[i][0]
|
||||
tw = torch.cat((t1, cond_from),1)
|
||||
n = [tw, conditioning_to[i][1].copy()]
|
||||
out.append(n)
|
||||
|
||||
return (out, )
|
||||
|
||||
class ConditioningSetArea:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -520,7 +548,7 @@ class DiffusersLoader:
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
CATEGORY = "advanced/loaders/deprecated"
|
||||
|
||||
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
|
||||
for search_path in folder_paths.get_folder_paths("diffusers"):
|
||||
@ -675,6 +703,21 @@ class ControlNetApply:
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class UNETLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_unet"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_unet(self, unet_name):
|
||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
||||
model = comfy.sd.load_unet(unet_path)
|
||||
return (model,)
|
||||
|
||||
class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1494,6 +1537,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"LatentCrop": LatentCrop,
|
||||
"LoraLoader": LoraLoader,
|
||||
"CLIPLoader": CLIPLoader,
|
||||
"UNETLoader": UNETLoader,
|
||||
"DualCLIPLoader": DualCLIPLoader,
|
||||
"CLIPVisionEncode": CLIPVisionEncode,
|
||||
"StyleModelApply": StyleModelApply,
|
||||
@ -1514,7 +1558,9 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
"LoadLatent": LoadLatent,
|
||||
"SaveLatent": SaveLatent,
|
||||
|
||||
"ConditioningZeroOut": ConditioningZeroOut,
|
||||
"ConditioningConcat": ConditioningConcat,
|
||||
"SavePreviewLatent": SavePreviewLatent,
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user