diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py
index ba04b9813..11d94c340 100644
--- a/comfy/diffusers_load.py
+++ b/comfy/diffusers_load.py
@@ -8,7 +8,8 @@ import os.path as osp
 import re
 import torch
 from safetensors.torch import load_file, save_file
-import diffusers_convert
+from . import diffusers_convert
+
 
 def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
     diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
diff --git a/comfy/gligen.py b/comfy/gligen.py
index fe3895c48..90558785b 100644
--- a/comfy/gligen.py
+++ b/comfy/gligen.py
@@ -215,10 +215,12 @@ class PositionNet(nn.Module):
     def forward(self, boxes, masks, positive_embeddings):
         B, N, _ = boxes.shape
-        masks = masks.unsqueeze(-1)
+        dtype = self.linears[0].weight.dtype
+        masks = masks.unsqueeze(-1).to(dtype)
+        positive_embeddings = positive_embeddings.to(dtype)
 
         # embedding position (it may includes padding as placeholder)
-        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 --> B*N*C
+        xyxy_embedding = self.fourier_embedder(boxes.to(dtype))  # B*N*4 --> B*N*C
 
         # learnable null embedding
         positive_null = self.null_positive_feature.view(1, 1, -1)
@@ -252,7 +254,8 @@ class Gligen(nn.Module):
         if self.lowvram == True:
             self.position_net.cpu()
 
-        def func_lowvram(key, x):
+        def func_lowvram(x, extra_options):
+            key = extra_options["transformer_index"]
             module = self.module_list[key]
             module.to(x.device)
             r = module(x, objs)
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 5f9eaa6eb..2284bcbdb 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -278,7 +278,7 @@ class CrossAttentionDoggettx(nn.Module):
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
         del q_in, k_in, v_in
 
-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
 
         mem_free_total = model_management.get_free_memory(q.device)
 
@@ -314,7 +314,7 @@ class CrossAttentionDoggettx(nn.Module):
                     s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
                 first_op_done = True
 
-                s2 = s1.softmax(dim=-1)
+                s2 = s1.softmax(dim=-1).to(v.dtype)
                 del s1
 
                 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index b198a270f..92f2438ef 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm
 
         self.in_layers = nn.Sequential(
-            normalization(channels, dtype=dtype),
+            nn.GroupNorm(32, channels, dtype=dtype),
             nn.SiLU(),
             conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
         )
@@ -244,7 +244,7 @@ class ResBlock(TimestepBlock):
             ),
         )
         self.out_layers = nn.Sequential(
-            normalization(self.out_channels, dtype=dtype),
+            nn.GroupNorm(32, self.out_channels, dtype=dtype),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
@@ -778,13 +778,13 @@ class UNetModel(nn.Module):
             self._feature_size += ch
 
         self.out = nn.Sequential(
-            normalization(ch, dtype=self.dtype),
+            nn.GroupNorm(32, ch, dtype=self.dtype),
             nn.SiLU(),
             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-            normalization(ch),
+            nn.GroupNorm(32, ch, dtype=self.dtype),
            conv_nd(dims, model_channels, n_embed, 1),
            #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
        )
@@ -821,7 +821,7 @@ class UNetModel(nn.Module):
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
         hs = []
-        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
         emb = self.time_embed(t_emb)
 
         if self.num_classes is not None:
diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index 573cce74f..4d42059b5 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -84,7 +84,7 @@ def _summarize_chunk(
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
     torch.exp(attn_weights - max_score, out=attn_weights)
-    exp_weights = attn_weights
+    exp_weights = attn_weights.to(value.dtype)
     exp_values = torch.bmm(exp_weights, value)
     max_score = max_score.squeeze(-1)
     return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
@@ -166,7 +166,7 @@ def _get_attention_scores_no_kv_chunking(
         attn_scores /= summed
         attn_probs = attn_scores
 
-    hidden_states_slice = torch.bmm(attn_probs, value)
+    hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
     return hidden_states_slice
 
 class ScannedChunk(NamedTuple):
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 60997246c..9197dc4b9 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -52,7 +52,13 @@ class BaseModel(torch.nn.Module):
         else:
             xc = x
         context = torch.cat(c_crossattn, 1)
-        return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)
+        dtype = self.get_dtype()
+        xc = xc.to(dtype)
+        t = t.to(dtype)
+        context = context.to(dtype)
+        if c_adm is not None:
+            c_adm = c_adm.to(dtype)
+        return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
 
     def get_dtype(self):
         return self.diffusion_model.dtype
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index edad48b1c..cf764e0b7 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -108,11 +108,13 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
     unet_config["context_dim"] = context_dim
     return unet_config
 
-
-def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
-    unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
+def model_config_from_unet_config(unet_config):
     for model_config in supported_models.models:
         if model_config.matches(unet_config):
             return model_config(unet_config)
 
     return None
+
+def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
+    unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
+    return model_config_from_unet_config(unet_config)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 574fbf214..a918a81f6 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -264,6 +264,7 @@ def load_model_gpu(model):
     torch_dev = model.load_device
     model.model_patches_to(torch_dev)
+    model.model_patches_to(model.model_dtype())
 
     if is_device_cpu(torch_dev):
         vram_set_state = VRAMState.DISABLED
diff --git a/comfy/sample.py b/comfy/sample.py
index dde5e42f8..48530f132 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -51,11 +51,11 @@ def get_models_from_cond(cond, model_type):
             models += [c[1][model_type]]
     return models
 
-def load_additional_models(positive, negative):
+def load_additional_models(positive, negative, dtype):
     """loads additional models in positive and negative conditioning"""
     control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
     gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
-    gligen = [x[1] for x in gligen]
+    gligen = [x[1].to(dtype) for x in gligen]
     models = control_nets + gligen
     comfy.model_management.load_controlnet_gpu(models)
     return models
@@ -81,7 +81,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     positive_copy = broadcast_cond(positive, noise.shape[0], device)
     negative_copy = broadcast_cond(negative, noise.shape[0], device)
 
-    models = load_additional_models(positive, negative)
+    models = load_additional_models(positive, negative, model.model_dtype())
 
     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
diff --git a/comfy/samplers.py b/comfy/samplers.py
index ea9525594..b5f79c058 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -2,7 +2,6 @@ from .k_diffusion import sampling as k_diffusion_sampling
 from .k_diffusion import external as k_diffusion_external
 from .extra_samplers import uni_pc
 import torch
-import contextlib
 from comfy import model_management
 from .ldm.models.diffusion.ddim import DDIMSampler
 from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
@@ -577,11 +576,6 @@ class KSampler:
             apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
             apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
 
-        if self.model.get_dtype() == torch.float16:
-            precision_scope = torch.autocast
-        else:
-            precision_scope = contextlib.nullcontext
-
         if self.model.is_adm():
             positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
             negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
@@ -612,67 +606,67 @@ class KSampler:
         else:
             max_denoise = True
 
-        with precision_scope(model_management.get_autocast_device(self.device)):
-            if self.sampler == "uni_pc":
-                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
-            elif self.sampler == "uni_pc_bh2":
-                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
-            elif self.sampler == "ddim":
-                timesteps = []
-                for s in range(sigmas.shape[0]):
-                    timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
-                noise_mask = None
-                if denoise_mask is not None:
-                    noise_mask = 1.0 - denoise_mask
-                ddim_callback = None
-                if callback is not None:
-                    total_steps = len(timesteps) - 1
-                    ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
+        if self.sampler == "uni_pc":
+            samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
+        elif self.sampler == "uni_pc_bh2":
+            samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
+        elif self.sampler == "ddim":
+            timesteps = []
+            for s in range(sigmas.shape[0]):
+                timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
+            noise_mask = None
+            if denoise_mask is not None:
+                noise_mask = 1.0 - denoise_mask
-                sampler = DDIMSampler(self.model, device=self.device)
-                sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
-                z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
-                samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
-                                                     conditioning=positive,
-                                                     batch_size=noise.shape[0],
-                                                     shape=noise.shape[1:],
-                                                     verbose=False,
-                                                     unconditional_guidance_scale=cfg,
-                                                     unconditional_conditioning=negative,
-                                                     eta=0.0,
-                                                     x_T=z_enc,
-                                                     x0=latent_image,
-                                                     img_callback=ddim_callback,
-                                                     denoise_function=sampling_function,
-                                                     extra_args=extra_args,
-                                                     mask=noise_mask,
-                                                     to_zero=sigmas[-1]==0,
-                                                     end_step=sigmas.shape[0] - 1,
-                                                     disable_pbar=disable_pbar)
+            ddim_callback = None
+            if callback is not None:
+                total_steps = len(timesteps) - 1
+                ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
+            sampler = DDIMSampler(self.model, device=self.device)
+            sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
+            z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
+            samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
+                                                 conditioning=positive,
+                                                 batch_size=noise.shape[0],
+                                                 shape=noise.shape[1:],
+                                                 verbose=False,
+                                                 unconditional_guidance_scale=cfg,
+                                                 unconditional_conditioning=negative,
+                                                 eta=0.0,
+                                                 x_T=z_enc,
+                                                 x0=latent_image,
+                                                 img_callback=ddim_callback,
+                                                 denoise_function=sampling_function,
+                                                 extra_args=extra_args,
+                                                 mask=noise_mask,
+                                                 to_zero=sigmas[-1]==0,
+                                                 end_step=sigmas.shape[0] - 1,
+                                                 disable_pbar=disable_pbar)
+
+        else:
+            extra_args["denoise_mask"] = denoise_mask
+            self.model_k.latent_image = latent_image
+            self.model_k.noise = noise
+
+            if max_denoise:
+                noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
             else:
-                extra_args["denoise_mask"] = denoise_mask
-                self.model_k.latent_image = latent_image
-                self.model_k.noise = noise
+                noise = noise * sigmas[0]
-                if max_denoise:
-                    noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
-                else:
-                    noise = noise * sigmas[0]
+            k_callback = None
+            total_steps = len(sigmas) - 1
+            if callback is not None:
+                k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
-                k_callback = None
-                total_steps = len(sigmas) - 1
-                if callback is not None:
-                    k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
-
-                if latent_image is not None:
-                    noise += latent_image
-                if self.sampler == "dpm_fast":
-                    samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
-                elif self.sampler == "dpm_adaptive":
-                    samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
-                else:
-                    samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+            if latent_image is not None:
+                noise += latent_image
+            if self.sampler == "dpm_fast":
+                samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+            elif self.sampler == "dpm_adaptive":
+                samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+            else:
+                samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
 
         return self.model.process_latent_out(samples.to(torch.float32))
diff --git a/comfy/sd.py b/comfy/sd.py
index 360f2962e..7e64536c1 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -291,7 +291,8 @@ class ModelPatcher:
                 patch_list[k] = patch_list[k].to(device)
 
     def model_dtype(self):
-        return self.model.get_dtype()
+        if hasattr(self.model, "get_dtype"):
+            return self.model.get_dtype()
 
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         p = {}
@@ -1049,7 +1050,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
             clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
 
     offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd)
+    model = model_config.get_model(sd, "model.diffusion_model.")
     model = model.to(offload_device)
     model.load_model_weights(sd, "model.diffusion_model.")
 
@@ -1073,6 +1074,73 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)
 
+
+def load_unet(unet_path): #load unet in diffusers format
+    sd = utils.load_torch_file(unet_path)
+    parameters = calculate_parameters(sd, "")
+    fp16 = model_management.should_use_fp16(model_params=parameters)
+
+    match = {}
+    match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
+    match["model_channels"] = sd["conv_in.weight"].shape[0]
+    match["in_channels"] = sd["conv_in.weight"].shape[1]
+    match["adm_in_channels"] = None
+    if "class_embedding.linear_1.weight" in sd:
+        match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
+
+    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
+            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
+            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+
+    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
+                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
+
+    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
+    print("match", match)
+    for unet_config in supported_models:
+        matches = True
+        for k in match:
+            if match[k] != unet_config[k]:
+                matches = False
+                break
+        if matches:
+            diffusers_keys = utils.unet_to_diffusers(unet_config)
+            new_sd = {}
+            for k in diffusers_keys:
+                if k in sd:
+                    new_sd[diffusers_keys[k]] = sd.pop(k)
+                else:
+                    print(diffusers_keys[k], k)
+            offload_device = model_management.unet_offload_device()
+            model_config = model_detection.model_config_from_unet_config(unet_config)
+            model = model_config.get_model(new_sd, "")
+            model = model.to(offload_device)
+            model.load_model_weights(new_sd, "")
+            return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+
 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     try:
         model.patch_model()
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 38a53ca7e..b1beee8c5 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -53,9 +53,9 @@ class SD20(supported_models_base.BASE):
     latent_format = latent_formats.SD15
 
-    def v_prediction(self, state_dict):
+    def v_prediction(self, state_dict, prefix=""):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
-            k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
+            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
             out = state_dict[k]
             if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                return True
@@ -109,7 +109,7 @@ class SDXLRefiner(supported_models_base.BASE):
     latent_format = latent_formats.SDXL
 
-    def get_model(self, state_dict):
+    def get_model(self, state_dict, prefix=""):
         return model_base.SDXLRefiner(self)
 
     def process_clip_state_dict(self, state_dict):
@@ -144,7 +144,7 @@ class SDXL(supported_models_base.BASE):
     latent_format = latent_formats.SDXL
 
-    def get_model(self, state_dict):
+    def get_model(self, state_dict, prefix=""):
         return model_base.SDXL(self)
 
     def process_clip_state_dict(self, state_dict):
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 0b0235ca4..86dc67068 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -41,7 +41,7 @@ class BASE:
                 return False
         return True
 
-    def v_prediction(self, state_dict):
+    def v_prediction(self, state_dict, prefix=""):
        return False
 
     def inpaint_model(self):
@@ -53,13 +53,13 @@ class BASE:
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]
 
-    def get_model(self, state_dict):
+    def get_model(self, state_dict, prefix=""):
         if self.inpaint_model():
-            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
+            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
         else:
-            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))
+            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))
 
     def process_clip_state_dict(self, state_dict):
         return state_dict
diff --git a/comfy/utils.py b/comfy/utils.py
index 25ccd944d..956ac1773 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -117,14 +117,33 @@ UNET_MAP_RESNET = {
     "out_layers.0.bias": "norm2.bias",
 }
 
+UNET_MAP_BASIC = {
+    "label_emb.0.0.weight": "class_embedding.linear_1.weight",
+    "label_emb.0.0.bias": "class_embedding.linear_1.bias",
+    "label_emb.0.2.weight": "class_embedding.linear_2.weight",
+    "label_emb.0.2.bias": "class_embedding.linear_2.bias",
+    "input_blocks.0.0.weight": "conv_in.weight",
+    "input_blocks.0.0.bias": "conv_in.bias",
+    "out.0.weight": "conv_norm_out.weight",
+    "out.0.bias": "conv_norm_out.bias",
+    "out.2.weight": "conv_out.weight",
+    "out.2.bias": "conv_out.bias",
+    "time_embed.0.weight": "time_embedding.linear_1.weight",
+    "time_embed.0.bias": "time_embedding.linear_1.bias",
+    "time_embed.2.weight": "time_embedding.linear_2.weight",
+    "time_embed.2.bias": "time_embedding.linear_2.bias"
+}
+
 def unet_to_diffusers(unet_config):
     num_res_blocks = unet_config["num_res_blocks"]
     attention_resolutions = unet_config["attention_resolutions"]
     channel_mult = unet_config["channel_mult"]
     transformer_depth = unet_config["transformer_depth"]
     num_blocks = len(channel_mult)
-    if not isinstance(num_res_blocks, list):
+    if isinstance(num_res_blocks, int):
         num_res_blocks = [num_res_blocks] * num_blocks
+    if isinstance(transformer_depth, int):
+        transformer_depth = [transformer_depth] * num_blocks
 
     transformers_per_layer = []
     res = 1
@@ -135,7 +154,7 @@ def unet_to_diffusers(unet_config):
         transformers_per_layer.append(transformers)
         res *= 2
 
-    transformers_mid = unet_config.get("transformer_depth_middle", transformers_per_layer[-1])
+    transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])
 
     diffusers_unet_map = {}
     for x in range(num_blocks):
@@ -185,6 +204,10 @@ def unet_to_diffusers(unet_config):
             for k in ["weight", "bias"]:
                 diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
         n += 1
+
+    for k in UNET_MAP_BASIC:
+        diffusers_unet_map[UNET_MAP_BASIC[k]] = k
+
     return diffusers_unet_map
 
 def convert_sd_to(state_dict, dtype):
diff --git a/folder_paths.py b/folder_paths.py
index 2ad1b1719..eb7d39b88 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -14,6 +14,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
 folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
 folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
 folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
+folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
 folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
 folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
 folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
diff --git a/models/unet/put_unet_files_here b/models/unet/put_unet_files_here
new file mode 100644
index 000000000..e69de29bb
diff --git a/nodes.py b/nodes.py
index 867715545..e722adb54 100644
--- a/nodes.py
+++ b/nodes.py
@@ -105,6 +105,34 @@ class ConditioningAverage :
             out.append(n)
         return (out, )
 
+class ConditioningConcat:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "conditioning_to": ("CONDITIONING",),
+            "conditioning_from": ("CONDITIONING",),
+            }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "concat"
+
+    CATEGORY = "advanced/conditioning"
+
+    def concat(self, conditioning_to, conditioning_from):
+        out = []
+
+        if len(conditioning_from) > 1:
+            print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
+
+        cond_from = conditioning_from[0][0]
+
+        for i in range(len(conditioning_to)):
+            t1 = conditioning_to[i][0]
+            tw = torch.cat((t1, cond_from),1)
+            n = [tw, conditioning_to[i][1].copy()]
+            out.append(n)
+
+        return (out, )
+
 class ConditioningSetArea:
     @classmethod
     def INPUT_TYPES(s):
@@ -520,7 +548,7 @@ class DiffusersLoader:
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"
 
-    CATEGORY = "advanced/loaders"
+    CATEGORY = "advanced/loaders/deprecated"
 
     def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
         for search_path in folder_paths.get_folder_paths("diffusers"):
@@ -675,6 +703,21 @@ class ControlNetApply:
             c.append(n)
         return (c, )
 
+class UNETLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
+                             }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "load_unet"
+
+    CATEGORY = "advanced/loaders"
+
+    def load_unet(self, unet_name):
+        unet_path = folder_paths.get_full_path("unet", unet_name)
+        model = comfy.sd.load_unet(unet_path)
+        return (model,)
+
 class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
@@ -1494,6 +1537,7 @@ NODE_CLASS_MAPPINGS = {
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "UNETLoader": UNETLoader,
     "DualCLIPLoader": DualCLIPLoader,
     "CLIPVisionEncode": CLIPVisionEncode,
     "StyleModelApply": StyleModelApply,
@@ -1514,7 +1558,9 @@
"LoadLatent": LoadLatent, "SaveLatent": SaveLatent, + "ConditioningZeroOut": ConditioningZeroOut, + "ConditioningConcat": ConditioningConcat, "SavePreviewLatent": SavePreviewLatent, }