diff --git a/README.md b/README.md
index 5e32a74f3..ad85d3d49 100644
--- a/README.md
+++ b/README.md
@@ -93,8 +93,8 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
 
 ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
 
-This is the command to install the nightly with ROCm 5.5 that supports the 7000 series and might have some performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.5 -r requirements.txt```
+This is the command to install the nightly with ROCm 5.6 that supports the 7000 series and might have some performance improvements:
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6 -r requirements.txt```
 
 ### NVIDIA
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 29e5fb159..83d8cd287 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -42,7 +42,7 @@ parser.add_argument("--auto-launch", action="store_true", help="Automatically la
 parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
 cm_group = parser.add_mutually_exclusive_group()
 cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
-cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Enable cudaMallocAsync.")
+cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
 
 parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
 
@@ -84,6 +84,8 @@ parser.add_argument("--dont-print-server", action="store_true", help="Don't prin
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
 parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
 
+parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
+
 args = parser.parse_args()
 
 if args.windows_standalone_build:
diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py
index 9688cbd52..a9eb9302f 100644
--- a/comfy/diffusers_convert.py
+++ b/comfy/diffusers_convert.py
@@ -148,6 +148,10 @@ vae_conversion_map_attn = [
     ("q.", "query."),
     ("k.", "key."),
     ("v.", "value."),
+    ("q.", "to_q."),
+    ("k.", "to_k."),
+    ("v.", "to_v."),
+    ("proj_out.", "to_out.0."),
     ("proj_out.", "proj_attn."),
 ]
 
diff --git a/comfy/k_diffusion/external.py b/comfy/k_diffusion/external.py
index 7335d56c4..c1a137d9c 100644
--- a/comfy/k_diffusion/external.py
+++ b/comfy/k_diffusion/external.py
@@ -91,7 +91,9 @@ class DiscreteSchedule(nn.Module):
         return log_sigma.exp()
 
     def predict_eps_discrete_timestep(self, input, t, **kwargs):
-        sigma = self.t_to_sigma(t.round())
+        if t.dtype != torch.int64 and t.dtype != torch.int32:
+            t = t.round()
+        sigma = self.t_to_sigma(t)
         input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
         return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
 
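The external.py change above stops rounding timesteps that are already integer tensors. A minimal standalone sketch of that guard (the function name is illustrative, not from the diff):

```python
import torch

def round_if_float(t: torch.Tensor) -> torch.Tensor:
    # Integer timesteps index the discrete sigma schedule exactly;
    # only floating-point timesteps need rounding first.
    if t.dtype != torch.int64 and t.dtype != torch.int32:
        t = t.round()
    return t

print(round_if_float(torch.tensor([1.4, 2.6])))  # tensor([1., 3.])
print(round_if_float(torch.tensor([1, 2])))      # left untouched (int64)
```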
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 3b4e99315..dd234435f 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -3,7 +3,6 @@ import math
 from scipy import integrate
 import torch
 from torch import nn
-from torchdiffeq import odeint
 import torchsde
 from tqdm.auto import trange, tqdm
 
@@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
     return x
 
 
-@torch.no_grad()
-def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    v = torch.randint_like(x, 2) * 2 - 1
-    fevals = 0
-    def ode_fn(sigma, x):
-        nonlocal fevals
-        with torch.enable_grad():
-            x = x[0].detach().requires_grad_()
-            denoised = model(x, sigma * s_in, **extra_args)
-            d = to_d(x, sigma, denoised)
-            fevals += 1
-            grad = torch.autograd.grad((d * v).sum(), x)[0]
-            d_ll = (v * grad).flatten(1).sum(1)
-        return d.detach(), d_ll
-    x_min = x, x.new_zeros([x.shape[0]])
-    t = x.new_tensor([sigma_min, sigma_max])
-    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
-    latent, delta_ll = sol[0][-1], sol[1][-1]
-    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
-    return ll_prior + delta_ll, {'fevals': fevals}
-
-
 class PIDStepSizeController:
     """A PID controller for ODE adaptive step size control."""
     def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 2d2d35814..d35f02a5b 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -164,7 +164,6 @@ class SDXLRefiner(BaseModel):
         else:
             aesthetic_score = kwargs.get("aesthetic_score", 6)
 
-        print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
         out = []
         out.append(self.embedder(torch.Tensor([height])))
         out.append(self.embedder(torch.Tensor([width])))
@@ -188,7 +187,6 @@ class SDXL(BaseModel):
         target_width = kwargs.get("target_width", width)
         target_height = kwargs.get("target_height", height)
 
-        print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
         out = []
         out.append(self.embedder(torch.Tensor([height])))
         out.append(self.embedder(torch.Tensor([width])))
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index cf764e0b7..691d4c6c4 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -118,3 +118,57 @@ def model_config_from_unet_config(unet_config):
 def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
     unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
     return model_config_from_unet_config(unet_config)
+
+
+def model_config_from_diffusers_unet(state_dict, use_fp16):
+    match = {}
+    match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
+    match["model_channels"] = state_dict["conv_in.weight"].shape[0]
+    match["in_channels"] = state_dict["conv_in.weight"].shape[1]
+    match["adm_in_channels"] = None
+    if "class_embedding.linear_1.weight" in state_dict:
+        match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
+    elif "add_embedding.linear_1.weight" in state_dict:
+        match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
+
+    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
+            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+
+    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
+                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
+
+    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
+
+    for unet_config in supported_models:
+        matches = True
+        for k in match:
+            if match[k] != unet_config[k]:
+                matches = False
+                break
+        if matches:
+            return model_config_from_unet_config(unet_config)
+    return None
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 34d22429a..241706925 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -49,6 +49,7 @@ except:
 try:
     if torch.backends.mps.is_available():
         cpu_state = CPUState.MPS
+        import torch.mps
 except:
     pass
 
@@ -280,19 +281,23 @@ def load_model_gpu(model):
         vram_set_state = VRAMState.LOW_VRAM
 
     real_model = model.model
+    patch_model_to = None
    if vram_set_state == VRAMState.DISABLED:
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
-        real_model.to(torch_dev)
+        patch_model_to = torch_dev
 
     try:
-        real_model = model.patch_model()
+        real_model = model.patch_model(device_to=patch_model_to)
     except Exception as e:
         model.unpatch_model()
         unload_model()
         raise e
+
+    if patch_model_to is not None:
+        real_model.to(torch_dev)
+
     if vram_set_state == VRAMState.NO_VRAM:
         device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
@@ -529,7 +534,7 @@ def should_use_fp16(device=None, model_params=0):
             return False
 
     #FP16 is just broken on these cards
-    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600"]
+    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450"]
     for x in nvidia_16_series:
         if x in props.name:
             return False
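How the new model_config_from_diffusers_unet matching works, reduced to a toy (the configs below are made up for illustration): a candidate config matches only when every probed key agrees.

```python
def find_config(match, candidates):
    # Return the first candidate whose values agree with every probed key.
    for cfg in candidates:
        if all(match[k] == cfg[k] for k in match):
            return cfg
    return None

candidates = [
    {"context_dim": 2048, "model_channels": 320, "in_channels": 4, "adm_in_channels": 2816},  # SDXL-like
    {"context_dim": 768, "model_channels": 320, "in_channels": 4, "adm_in_channels": None},   # SD1.5-like
]
probe = {"context_dim": 768, "model_channels": 320, "in_channels": 4, "adm_in_channels": None}
print(find_config(probe, candidates))  # picks the SD1.5-like entry
```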
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 50fda016d..044d518a5 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -17,6 +17,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
     def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
+        if 'timestep_start' in cond[1]:
+            timestep_start = cond[1]['timestep_start']
+            if timestep_in[0] > timestep_start:
+                return None
+        if 'timestep_end' in cond[1]:
+            timestep_end = cond[1]['timestep_end']
+            if timestep_in[0] < timestep_end:
+                return None
         if 'area' in cond[1]:
             area = cond[1]['area']
         if 'strength' in cond[1]:
@@ -248,7 +256,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
 
             c['transformer_options'] = transformer_options
 
-            output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
+            if 'model_function_wrapper' in model_options:
+                output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
+            else:
+                output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
             del input_x
 
             model_management.throw_exception_if_processing_interrupted()
@@ -425,6 +436,35 @@ def create_cond_with_same_area_if_none(conds, c):
         n = c[1].copy()
     conds += [[smallest[0], n]]
 
+def calculate_start_end_timesteps(model, conds):
+    for t in range(len(conds)):
+        x = conds[t]
+
+        timestep_start = None
+        timestep_end = None
+        if 'start_percent' in x[1]:
+            timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
+        if 'end_percent' in x[1]:
+            timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
+
+        if (timestep_start is not None) or (timestep_end is not None):
+            n = x[1].copy()
+            if (timestep_start is not None):
+                n['timestep_start'] = timestep_start
+            if (timestep_end is not None):
+                n['timestep_end'] = timestep_end
+            conds[t] = [x[0], n]
+
+def pre_run_control(model, conds):
+    for t in range(len(conds)):
+        x = conds[t]
+
+        timestep_start = None
+        timestep_end = None
+        percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
+        if 'control' in x[1]:
+            x[1]['control'].pre_run(model.inner_model, percent_to_timestep_function)
+
 def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
     cond_cnets = []
     cond_other = []
@@ -568,13 +608,18 @@ class KSampler:
         resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
         resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
 
+        calculate_start_end_timesteps(self.model_wrap, negative)
+        calculate_start_end_timesteps(self.model_wrap, positive)
+
         #make sure each cond area has an opposite one with the same area
         for c in positive:
             create_cond_with_same_area_if_none(negative, c)
         for c in negative:
             create_cond_with_same_area_if_none(positive, c)
 
-        apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
+        pre_run_control(self.model_wrap, negative + positive)
+
+        apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
         apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
 
         if self.model.is_adm():
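The samplers.py additions gate each cond on a timestep window: get_area_and_mult returns None outside [timestep_end, timestep_start], since timesteps count down during sampling. A self-contained sketch of that predicate, assuming the same convention:

```python
def cond_active(t, timestep_start=None, timestep_end=None):
    # Active only while timestep_end <= t <= timestep_start.
    if timestep_start is not None and t > timestep_start:
        return False
    if timestep_end is not None and t < timestep_end:
        return False
    return True

print(cond_active(500, timestep_start=999, timestep_end=250))  # True
print(cond_active(100, timestep_start=999, timestep_end=250))  # False
```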
diff --git a/comfy/sd.py b/comfy/sd.py
index a7887a82b..70701ab6b 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -170,6 +170,8 @@ def model_lora_keys_clip(model, key_map={}):
             if k in sdk:
                 lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                 key_map[lora_key] = k
+                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
+                key_map[lora_key] = k
 
             k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
@@ -202,6 +204,14 @@ def model_lora_keys_unet(model, key_map={}):
             key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
     return key_map
 
+def set_attr(obj, attr, value):
+    attrs = attr.split(".")
+    for name in attrs[:-1]:
+        obj = getattr(obj, name)
+    prev = getattr(obj, attrs[-1])
+    setattr(obj, attrs[-1], torch.nn.Parameter(value))
+    del prev
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0):
         self.size = size
@@ -330,7 +340,7 @@ class ModelPatcher:
                 sd.pop(k)
         return sd
 
-    def patch_model(self):
+    def patch_model(self, device_to=None):
         model_sd = self.model_state_dict()
         for key in self.patches:
             if key not in model_sd:
@@ -340,10 +350,14 @@ class ModelPatcher:
             weight = model_sd[key]
 
             if key not in self.backup:
-                self.backup[key] = weight.to(self.offload_device, copy=True)
+                self.backup[key] = weight.to(self.offload_device)
 
-            temp_weight = weight.to(torch.float32, copy=True)
-            weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            if device_to is not None:
+                temp_weight = weight.float().to(device_to, copy=True)
+            else:
+                temp_weight = weight.to(torch.float32, copy=True)
+            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+            set_attr(self.model, key, out_weight)
             del temp_weight
 
         return self.model
@@ -376,7 +390,10 @@ class ModelPatcher:
                     mat3 = v[3].float().to(weight.device)
                     final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
-                weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                try:
+                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
             elif len(v) == 8: #lokr
                 w1 = v[0]
                 w2 = v[1]
@@ -407,7 +424,10 @@ class ModelPatcher:
                 if v[2] is not None and dim is not None:
                     alpha *= v[2] / dim
 
-                weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                try:
+                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
             else: #loha
                 w1a = v[0]
                 w1b = v[1]
@@ -424,18 +444,15 @@ class ModelPatcher:
                     m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
                     m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
 
-                weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                try:
+                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    print("ERROR", key, e)
+
         return weight
 
     def unpatch_model(self):
         keys = list(self.backup.keys())
-        def set_attr(obj, attr, value):
-            attrs = attr.split(".")
-            for name in attrs[:-1]:
-                obj = getattr(obj, name)
-            prev = getattr(obj, attrs[-1])
-            setattr(obj, attrs[-1], torch.nn.Parameter(value))
-            del prev
 
         for k in keys:
             set_attr(self.model, k, self.backup[k])
@@ -658,16 +675,57 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
     else:
         return torch.cat([tensor] * batched_number, dim=0)
 
-class ControlNet:
-    def __init__(self, control_model, global_average_pooling=False, device=None):
-        self.control_model = control_model
+class ControlBase:
+    def __init__(self, device=None):
         self.cond_hint_original = None
         self.cond_hint = None
         self.strength = 1.0
+        self.timestep_percent_range = (1.0, 0.0)
+        self.timestep_range = None
+
         if device is None:
             device = model_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
+
+    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
+        self.cond_hint_original = cond_hint
+        self.strength = strength
+        self.timestep_percent_range = timestep_percent_range
+        return self
+
+    def pre_run(self, model, percent_to_timestep_function):
+        self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
+        if self.previous_controlnet is not None:
+            self.previous_controlnet.pre_run(model, percent_to_timestep_function)
+
+    def set_previous_controlnet(self, controlnet):
+        self.previous_controlnet = controlnet
+        return self
+
+    def cleanup(self):
+        if self.previous_controlnet is not None:
+            self.previous_controlnet.cleanup()
+        if self.cond_hint is not None:
+            del self.cond_hint
+            self.cond_hint = None
+        self.timestep_range = None
+
+    def get_models(self):
+        out = []
+        if self.previous_controlnet is not None:
+            out += self.previous_controlnet.get_models()
+        return out
+
+    def copy_to(self, c):
+        c.cond_hint_original = self.cond_hint_original
+        c.strength = self.strength
+        c.timestep_percent_range = self.timestep_percent_range
+
+class ControlNet(ControlBase):
+    def __init__(self, control_model, global_average_pooling=False, device=None):
+        super().__init__(device)
+        self.control_model = control_model
         self.global_average_pooling = global_average_pooling
 
     def get_control(self, x_noisy, t, cond, batched_number):
@@ -675,6 +733,13 @@ class ControlNet:
         if self.previous_controlnet is not None:
             control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
 
+        if self.timestep_range is not None:
+            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+                if control_prev is not None:
+                    return control_prev
+                else:
+                    return {}
+
         output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
@@ -722,37 +787,64 @@ class ControlNet:
             out['input'] = control_prev['input']
         return out
 
-    def set_cond_hint(self, cond_hint, strength=1.0):
-        self.cond_hint_original = cond_hint
-        self.strength = strength
-        return self
-
-    def set_previous_controlnet(self, controlnet):
-        self.previous_controlnet = controlnet
-        return self
-
-    def cleanup(self):
-        if self.previous_controlnet is not None:
-            self.previous_controlnet.cleanup()
-        if self.cond_hint is not None:
-            del self.cond_hint
-            self.cond_hint = None
-
     def copy(self):
         c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
-        c.cond_hint_original = self.cond_hint_original
-        c.strength = self.strength
+        self.copy_to(c)
         return c
 
     def get_models(self):
-        out = []
-        if self.previous_controlnet is not None:
-            out += self.previous_controlnet.get_models()
+        out = super().get_models()
         out.append(self.control_model)
         return out
 
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
+
+    controlnet_config = None
+    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
+        use_fp16 = model_management.should_use_fp16()
+        controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config
+        diffusers_keys = utils.unet_to_diffusers(controlnet_config)
+        diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
+        diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
+
+        count = 0
+        loop = True
+        while loop:
+            suffix = [".weight", ".bias"]
+            for s in suffix:
+                k_in = "controlnet_down_blocks.{}{}".format(count, s)
+                k_out = "zero_convs.{}.0{}".format(count, s)
+                if k_in not in controlnet_data:
+                    loop = False
+                    break
+                diffusers_keys[k_in] = k_out
+            count += 1
+
+        count = 0
+        loop = True
+        while loop:
+            suffix = [".weight", ".bias"]
+            for s in suffix:
+                if count == 0:
+                    k_in = "controlnet_cond_embedding.conv_in{}".format(s)
+                else:
+                    k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
+                k_out = "input_hint_block.{}{}".format(count * 2, s)
+                if k_in not in controlnet_data:
+                    k_in = "controlnet_cond_embedding.conv_out{}".format(s)
+                    loop = False
+                diffusers_keys[k_in] = k_out
+            count += 1
+
+        new_sd = {}
+        for k in diffusers_keys:
+            if k in controlnet_data:
+                new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
+
+        controlnet_data = new_sd
+
     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
     key = 'zero_convs.0.0.weight'
@@ -768,9 +860,9 @@ def load_controlnet(ckpt_path, model=None):
         print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
         return net
 
-    use_fp16 = model_management.should_use_fp16()
-
-    controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
+    if controlnet_config is None:
+        use_fp16 = model_management.should_use_fp16()
+        controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = 3
     control_model = cldm.ControlNet(**controlnet_config)
@@ -810,24 +902,25 @@ def load_controlnet(ckpt_path, model=None):
         control = ControlNet(control_model, global_average_pooling=global_average_pooling)
     return control
 
-class T2IAdapter:
+class T2IAdapter(ControlBase):
     def __init__(self, t2i_model, channels_in, device=None):
+        super().__init__(device)
         self.t2i_model = t2i_model
         self.channels_in = channels_in
-        self.strength = 1.0
-        if device is None:
-            device = model_management.get_torch_device()
-        self.device = device
-        self.previous_controlnet = None
         self.control_input = None
-        self.cond_hint_original = None
-        self.cond_hint = None
 
     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
         if self.previous_controlnet is not None:
             control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
 
+        if self.timestep_range is not None:
+            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+                if control_prev is not None:
+                    return control_prev
+                else:
+                    return {}
+
         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
                 del self.cond_hint
@@ -872,33 +965,11 @@ class T2IAdapter:
             out['output'] = control_prev['output']
         return out
 
-    def set_cond_hint(self, cond_hint, strength=1.0):
-        self.cond_hint_original = cond_hint
-        self.strength = strength
-        return self
-
-    def set_previous_controlnet(self, controlnet):
-        self.previous_controlnet = controlnet
-        return self
 
     def copy(self):
         c = T2IAdapter(self.t2i_model, self.channels_in)
-        c.cond_hint_original = self.cond_hint_original
-        c.strength = self.strength
+        self.copy_to(c)
        return c
 
-    def cleanup(self):
-        if self.previous_controlnet is not None:
-            self.previous_controlnet.cleanup()
-        if self.cond_hint is not None:
-            del self.cond_hint
-            self.cond_hint = None
-
-    def get_models(self):
-        out = []
-        if self.previous_controlnet is not None:
-            out += self.previous_controlnet.get_models()
-        return out
-
 def load_t2i_adapter(t2i_data):
     keys = t2i_data.keys()
@@ -1128,66 +1199,24 @@ def load_unet(unet_path): #load unet in diffusers format
     parameters = calculate_parameters(sd, "")
     fp16 = model_management.should_use_fp16(model_params=parameters)
 
-    match = {}
-    match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
-    match["model_channels"] = sd["conv_in.weight"].shape[0]
-    match["in_channels"] = sd["conv_in.weight"].shape[1]
-    match["adm_in_channels"] = None
-    if "class_embedding.linear_1.weight" in sd:
-        match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
+    model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
+    if model_config is None:
+        print("ERROR UNSUPPORTED UNET", unet_path)
+        return None
 
-    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
-            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
-            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+    diffusers_keys = utils.unet_to_diffusers(model_config.unet_config)
 
-    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
-                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
-
-    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
-            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
-
-    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
-                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
-
-    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
-                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
-
-    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
-            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
-
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
-    print("match", match)
-    for unet_config in supported_models:
-        matches = True
-        for k in match:
-            if match[k] != unet_config[k]:
-                matches = False
-                break
-        if matches:
-            diffusers_keys = utils.unet_to_diffusers(unet_config)
-            new_sd = {}
-            for k in diffusers_keys:
-                if k in sd:
-                    new_sd[diffusers_keys[k]] = sd.pop(k)
-                else:
-                    print(diffusers_keys[k], k)
-            offload_device = model_management.unet_offload_device()
-            model_config = model_detection.model_config_from_unet_config(unet_config)
-            model = model_config.get_model(new_sd, "")
-            model = model.to(offload_device)
-            model.load_model_weights(new_sd, "")
-            return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+    new_sd = {}
+    for k in diffusers_keys:
+        if k in sd:
+            new_sd[diffusers_keys[k]] = sd.pop(k)
+        else:
+            print(diffusers_keys[k], k)
+    offload_device = model_management.unet_offload_device()
+    model = model_config.get_model(new_sd, "")
+    model = model.to(offload_device)
+    model.load_model_weights(new_sd, "")
+    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
 
 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     try:
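The set_attr helper promoted to module level in sd.py walks a dotted attribute path and swaps the final parameter in place. A toy use on a plain torch module:

```python
import torch

def set_attr(obj, attr, value):
    # Same logic as the helper in sd.py: resolve "a.b.c" to the parent of "c",
    # then replace the old parameter with a new one wrapping `value`.
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    setattr(obj, attrs[-1], torch.nn.Parameter(value))
    del prev

m = torch.nn.Sequential(torch.nn.Linear(2, 2))
set_attr(m, "0.weight", torch.zeros(2, 2))
print(m[0].weight)  # now an all-zero Parameter
```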
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") for k in state_dict: if k.startswith("clip_l"): state_dict_g[k] = state_dict[k] diff --git a/comfy/utils.py b/comfy/utils.py index d410e6af6..3bbe4f9a9 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -120,20 +120,24 @@ UNET_MAP_RESNET = { } UNET_MAP_BASIC = { - "label_emb.0.0.weight": "class_embedding.linear_1.weight", - "label_emb.0.0.bias": "class_embedding.linear_1.bias", - "label_emb.0.2.weight": "class_embedding.linear_2.weight", - "label_emb.0.2.bias": "class_embedding.linear_2.bias", - "input_blocks.0.0.weight": "conv_in.weight", - "input_blocks.0.0.bias": "conv_in.bias", - "out.0.weight": "conv_norm_out.weight", - "out.0.bias": "conv_norm_out.bias", - "out.2.weight": "conv_out.weight", - "out.2.bias": "conv_out.bias", - "time_embed.0.weight": "time_embedding.linear_1.weight", - "time_embed.0.bias": "time_embedding.linear_1.bias", - "time_embed.2.weight": "time_embedding.linear_2.weight", - "time_embed.2.bias": "time_embedding.linear_2.bias" + ("label_emb.0.0.weight", "class_embedding.linear_1.weight"), + ("label_emb.0.0.bias", "class_embedding.linear_1.bias"), + ("label_emb.0.2.weight", "class_embedding.linear_2.weight"), + ("label_emb.0.2.bias", "class_embedding.linear_2.bias"), + ("label_emb.0.0.weight", "add_embedding.linear_1.weight"), + ("label_emb.0.0.bias", "add_embedding.linear_1.bias"), + ("label_emb.0.2.weight", "add_embedding.linear_2.weight"), + ("label_emb.0.2.bias", "add_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias") } def unet_to_diffusers(unet_config): @@ -208,7 +212,7 @@ def unet_to_diffusers(unet_config): n += 1 for k in UNET_MAP_BASIC: - diffusers_unet_map[UNET_MAP_BASIC[k]] = k + diffusers_unet_map[k[1]] = k[0] return diffusers_unet_map diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index 95c4cfece..bce4b3dd0 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -1,9 +1,13 @@ import comfy.sd import comfy.utils +import comfy.model_base + import folder_paths import json import os +from comfy.cli_args import args + class ModelMergeSimple: @classmethod def INPUT_TYPES(s): @@ -99,10 +103,36 @@ class CheckpointSave: if prompt is not None: prompt_info = json.dumps(prompt) - metadata = {"prompt": prompt_info} - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) + metadata = {} + + enable_modelspec = True + if isinstance(model.model, comfy.model_base.SDXL): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" + elif isinstance(model.model, comfy.model_base.SDXLRefiner): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" + else: + enable_modelspec = False + + if enable_modelspec: + metadata["modelspec.sai_model_spec"] = "1.0.0" + metadata["modelspec.implementation"] = "sgm" + metadata["modelspec.title"] = "{} {}".format(filename, counter) + + #TODO: + # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512", + # "stable-diffusion-v2-768-v", 
"stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", + # "v2-inpainting" + + if model.model.model_type == comfy.model_base.ModelType.EPS: + metadata["modelspec.predict_key"] = "epsilon" + elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: + metadata["modelspec.predict_key"] = "v" + + if not args.disable_metadata: + metadata["prompt"] = prompt_info + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index f9252ea0b..abd182e6e 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -37,12 +37,23 @@ class ImageUpscaleWithModel: device = model_management.get_torch_device() upscale_model.to(device) in_img = image.movedim(-1,-3).to(device) + free_memory = model_management.get_free_memory(device) + + tile = 512 + overlap = 32 + + oom = True + while oom: + try: + steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) + pbar = comfy.utils.ProgressBar(steps) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) + oom = False + except model_management.OOM_EXCEPTION as e: + tile //= 2 + if tile < 128: + raise e - tile = 128 + 64 - overlap = 8 - steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) - pbar = comfy.utils.ProgressBar(steps) - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) upscale_model.cpu() s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) return (s,) diff --git a/cuda_malloc.py b/cuda_malloc.py new file mode 100644 index 000000000..a808b2071 --- /dev/null +++ b/cuda_malloc.py @@ -0,0 +1,81 @@ +import os +import importlib.util +from comfy.cli_args import args + +#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. 
diff --git a/cuda_malloc.py b/cuda_malloc.py
new file mode 100644
index 000000000..a808b2071
--- /dev/null
+++ b/cuda_malloc.py
@@ -0,0 +1,81 @@
+import os
+import importlib.util
+from comfy.cli_args import args
+
+#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
+def get_gpu_names():
+    if os.name == 'nt':
+        import ctypes
+
+        # Define necessary C structures and types
+        class DISPLAY_DEVICEA(ctypes.Structure):
+            _fields_ = [
+                ('cb', ctypes.c_ulong),
+                ('DeviceName', ctypes.c_char * 32),
+                ('DeviceString', ctypes.c_char * 128),
+                ('StateFlags', ctypes.c_ulong),
+                ('DeviceID', ctypes.c_char * 128),
+                ('DeviceKey', ctypes.c_char * 128)
+            ]
+
+        # Load user32.dll
+        user32 = ctypes.windll.user32
+
+        # Call EnumDisplayDevicesA
+        def enum_display_devices():
+            device_info = DISPLAY_DEVICEA()
+            device_info.cb = ctypes.sizeof(device_info)
+            device_index = 0
+            gpu_names = set()
+
+            while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
+                device_index += 1
+                gpu_names.add(device_info.DeviceString.decode('utf-8'))
+            return gpu_names
+        return enum_display_devices()
+    else:
+        return set()
+
+def cuda_malloc_supported():
+    blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
+                 "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
+                 "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
+                 "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000"}
+
+    try:
+        names = get_gpu_names()
+    except:
+        names = set()
+    for x in names:
+        if "NVIDIA" in x:
+            for b in blacklist:
+                if b in x:
+                    return False
+    return True
+
+
+if not args.cuda_malloc:
+    try:
+        version = ""
+        torch_spec = importlib.util.find_spec("torch")
+        for folder in torch_spec.submodule_search_locations:
+            ver_file = os.path.join(folder, "version.py")
+            if os.path.isfile(ver_file):
+                spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
+                module = importlib.util.module_from_spec(spec)
+                spec.loader.exec_module(module)
+                version = module.__version__
+        if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
+            args.cuda_malloc = cuda_malloc_supported()
+    except:
+        pass
+
+
+if args.cuda_malloc and not args.disable_cuda_malloc:
+    env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
+    if env_var is None:
+        env_var = "backend:cudaMallocAsync"
+    else:
+        env_var += ",backend:cudaMallocAsync"
+
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
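The allocator setup in cuda_malloc.py appends to any existing PYTORCH_CUDA_ALLOC_CONF rather than clobbering it, and it must run before torch is first imported for the backend switch to take effect. The same logic in isolation:

```python
import os

def enable_cuda_malloc_async():
    # Preserve user-provided allocator options and add the async backend.
    env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
    if env_var is None:
        env_var = "backend:cudaMallocAsync"
    else:
        env_var += ",backend:cudaMallocAsync"
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var

enable_cuda_malloc_async()
print(os.environ['PYTORCH_CUDA_ALLOC_CONF'])
```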
diff --git a/latent_preview.py b/latent_preview.py
index 833e6822e..30c1d1317 100644
--- a/latent_preview.py
+++ b/latent_preview.py
@@ -1,6 +1,5 @@
 import torch
-from PIL import Image, ImageOps
-from io import BytesIO
+from PIL import Image
 import struct
 import numpy as np
 from comfy.cli_args import args, LatentPreviewMethod
@@ -15,26 +14,7 @@ class LatentPreviewer:
 
     def decode_latent_to_preview_image(self, preview_format, x0):
         preview_image = self.decode_latent_to_preview(x0)
-
-        if hasattr(Image, 'Resampling'):
-            resampling = Image.Resampling.BILINEAR
-        else:
-            resampling = Image.ANTIALIAS
-
-        preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), resampling)
-
-        preview_type = 1
-        if preview_format == "JPEG":
-            preview_type = 1
-        elif preview_format == "PNG":
-            preview_type = 2
-
-        bytesIO = BytesIO()
-        header = struct.pack(">I", preview_type)
-        bytesIO.write(header)
-        preview_image.save(bytesIO, format=preview_format, quality=95)
-        preview_bytes = bytesIO.getvalue()
-        return preview_bytes
+        return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
 
 class TAESDPreviewerImpl(LatentPreviewer):
     def __init__(self, taesd):
diff --git a/main.py b/main.py
index 61ba9e8e6..21f76b617 100644
--- a/main.py
+++ b/main.py
@@ -61,30 +61,7 @@ if __name__ == "__main__":
         os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
         print("Set cuda device to:", args.cuda_device)
 
-    if not args.cuda_malloc:
-        try: #if there's a better way to check the torch version without importing it let me know
-            version = ""
-            torch_spec = importlib.util.find_spec("torch")
-            for folder in torch_spec.submodule_search_locations:
-                ver_file = os.path.join(folder, "version.py")
-                if os.path.isfile(ver_file):
-                    spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
-                    module = importlib.util.module_from_spec(spec)
-                    spec.loader.exec_module(module)
-                    version = module.__version__
-            if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
-                args.cuda_malloc = True
-        except:
-            pass
-
-    if args.cuda_malloc and not args.disable_cuda_malloc:
-        env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
-        if env_var is None:
-            env_var = "backend:cudaMallocAsync"
-        else:
-            env_var += ",backend:cudaMallocAsync"
-
-        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
+    import cuda_malloc
 
 import comfy.utils
 import yaml
@@ -115,10 +92,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
 
 
 def hijack_progress(server):
-    def hook(value, total, preview_image_bytes):
+    def hook(value, total, preview_image):
         server.send_sync("progress", {"value": value, "max": total}, server.client_id)
-        if preview_image_bytes is not None:
-            server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
+        if preview_image is not None:
+            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
     comfy.utils.set_progress_bar_global_hook(hook)
 
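With latent_preview.py no longer encoding, the progress hook receives an unencoded preview as a (format, PIL image, max_size) tuple and forwards it as UNENCODED_PREVIEW_IMAGE. A stand-in hook showing the assumed shape of that tuple:

```python
from PIL import Image

def hook(value, total, preview_image):
    # preview_image is either None or ("JPEG"/"PNG", PIL.Image, max_size)
    print(f"progress {value}/{total}")
    if preview_image is not None:
        fmt, img, max_size = preview_image
        print(fmt, img.size, max_size)

hook(1, 20, ("JPEG", Image.new("RGB", (64, 64)), 512))
```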
diff --git a/nodes.py b/nodes.py
index a1c4b8437..240619ed1 100644
--- a/nodes.py
+++ b/nodes.py
@@ -26,6 +26,8 @@ import comfy.utils
 import comfy.clip_vision
 import comfy.model_management
 
+from comfy.cli_args import args
+
 import importlib
 
 import folder_paths
@@ -204,6 +206,28 @@ class ConditioningZeroOut:
             c.append(n)
         return (c, )
 
+class ConditioningSetTimestepRange:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", ),
+                             "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                             "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
+                             }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "set_range"
+
+    CATEGORY = "advanced/conditioning"
+
+    def set_range(self, conditioning, start, end):
+        c = []
+        for t in conditioning:
+            d = t[1].copy()
+            d['start_percent'] = 1.0 - start
+            d['end_percent'] = 1.0 - end
+            n = [t[0], d]
+            c.append(n)
+        return (c, )
+
 class VAEDecode:
     @classmethod
     def INPUT_TYPES(s):
@@ -330,10 +354,12 @@ class SaveLatent:
         if prompt is not None:
             prompt_info = json.dumps(prompt)
 
-        metadata = {"prompt": prompt_info}
-        if extra_pnginfo is not None:
-            for x in extra_pnginfo:
-                metadata[x] = json.dumps(extra_pnginfo[x])
+        metadata = None
+        if not args.disable_metadata:
+            metadata = {"prompt": prompt_info}
+            if extra_pnginfo is not None:
+                for x in extra_pnginfo:
+                    metadata[x] = json.dumps(extra_pnginfo[x])
 
         file = f"{filename}_{counter:05}_.latent"
         file = os.path.join(full_output_folder, file)
@@ -580,9 +606,58 @@ class ControlNetApply:
             if 'control' in t[1]:
                 c_net.set_previous_controlnet(t[1]['control'])
             n[1]['control'] = c_net
+            n[1]['control_apply_to_uncond'] = True
             c.append(n)
         return (c, )
 
+
+class ControlNetApplyAdvanced:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "control_net": ("CONTROL_NET", ),
+                             "image": ("IMAGE", ),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                             "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                             "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
+                             }}
+
+    RETURN_TYPES = ("CONDITIONING","CONDITIONING")
+    RETURN_NAMES = ("positive", "negative")
+    FUNCTION = "apply_controlnet"
+
+    CATEGORY = "conditioning"
+
+    def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent):
+        if strength == 0:
+            return (positive, negative)
+
+        control_hint = image.movedim(-1,1)
+        cnets = {}
+
+        out = []
+        for conditioning in [positive, negative]:
+            c = []
+            for t in conditioning:
+                d = t[1].copy()
+
+                prev_cnet = d.get('control', None)
+                if prev_cnet in cnets:
+                    c_net = cnets[prev_cnet]
+                else:
+                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
+                    c_net.set_previous_controlnet(prev_cnet)
+                    cnets[prev_cnet] = c_net
+
+                d['control'] = c_net
+                d['control_apply_to_uncond'] = False
+                n = [t[0], d]
+                c.append(n)
+            out.append(c)
+        return (out[0], out[1])
+
+
 class UNETLoader:
     @classmethod
     def INPUT_TYPES(s):
@@ -1143,12 +1218,14 @@ class SaveImage:
         for image in images:
             i = 255. * image.cpu().numpy()
             img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
-            metadata = PngInfo()
-            if prompt is not None:
-                metadata.add_text("prompt", json.dumps(prompt))
-            if extra_pnginfo is not None:
-                for x in extra_pnginfo:
-                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))
+            metadata = None
+            if not args.disable_metadata:
+                metadata = PngInfo()
+                if prompt is not None:
+                    metadata.add_text("prompt", json.dumps(prompt))
+                if extra_pnginfo is not None:
+                    for x in extra_pnginfo:
+                        metadata.add_text(x, json.dumps(extra_pnginfo[x]))
 
             file = f"{filename}_{counter:05}_.png"
             img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
@@ -1427,6 +1504,7 @@ NODE_CLASS_MAPPINGS = {
     "StyleModelApply": StyleModelApply,
     "unCLIPConditioning": unCLIPConditioning,
     "ControlNetApply": ControlNetApply,
+    "ControlNetApplyAdvanced": ControlNetApplyAdvanced,
     "ControlNetLoader": ControlNetLoader,
     "DiffControlNetLoader": DiffControlNetLoader,
     "StyleModelLoader": StyleModelLoader,
@@ -1444,6 +1522,7 @@ NODE_CLASS_MAPPINGS = {
     "SaveLatent": SaveLatent,
 
     "ConditioningZeroOut": ConditioningZeroOut,
+    "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -1472,6 +1551,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ConditioningSetArea": "Conditioning (Set Area)",
     "ConditioningSetMask": "Conditioning (Set Mask)",
     "ControlNetApply": "Apply ControlNet",
+    "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
     # Latent
     "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
     "SetLatentNoiseMask": "Set Latent Noise Mask",
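Both new nodes invert the user-facing percents before storing them: start/end are fractions of the sampling schedule (0.0 = first step), while the internal range is kept as (1.0 - start, 1.0 - end) because the timestep values they map to count down. A toy check of that convention:

```python
def to_internal(start, end):
    # Mirrors d['start_percent'] = 1.0 - start / d['end_percent'] = 1.0 - end.
    return (1.0 - start, 1.0 - end)

print(to_internal(0.0, 1.0))  # (1.0, 0.0): active over the whole schedule
print(to_internal(0.0, 0.5))  # (1.0, 0.5): only the first (high-noise) half
```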
diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb
index 61c277bf6..1bb90f7d0 100644
--- a/notebooks/comfyui_colab.ipynb
+++ b/notebooks/comfyui_colab.ipynb
@@ -69,6 +69,13 @@
    "source": [
     "# Checkpoints\n",
     "\n",
+    "### SDXL\n",
+    "### I recommend these workflow examples: https://comfyanonymous.github.io/ComfyUI_examples/sdxl/\n",
+    "\n",
+    "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n",
+    "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n",
+    "\n",
+    "\n",
     "# SD1.5\n",
     "!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
     "\n",
@@ -83,7 +90,7 @@
     "#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n",
     "\n",
     "# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n",
-    "#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta2/resolve/main/checkpoints/wd-1-5-beta2-fp16.safetensors -P ./models/checkpoints/\n",
+    "#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta3/resolve/main/wd-illusion-fp16.safetensors -P ./models/checkpoints/\n",
     "\n",
     "\n",
     "# unCLIP models\n",
@@ -100,6 +107,7 @@
     "# Loras\n",
     "#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n",
     "#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n",
+    "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors -P ./models/loras/ #SDXL offset noise lora\n",
     "\n",
     "\n",
     "# T2I-Adapter\n",
diff --git a/requirements.txt b/requirements.txt
index d632edf79..8ee7b83d1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,4 @@
 torch
-torchdiffeq
 torchsde
 einops
 transformers>=4.25.1
#this is the one for the default workflow diff --git a/server.py b/server.py index 9ca131ede..f61b11a97 100644 --- a/server.py +++ b/server.py @@ -8,7 +8,7 @@ import uuid import json import glob import struct -from PIL import Image +from PIL import Image, ImageOps from io import BytesIO try: @@ -29,6 +29,7 @@ import comfy.model_management class BinaryEventTypes: PREVIEW_IMAGE = 1 + UNENCODED_PREVIEW_IMAGE = 2 async def send_socket_catch_exception(function, message): try: @@ -498,7 +499,9 @@ class PromptServer(): return prompt_info async def send(self, event, data, sid=None): - if isinstance(data, (bytes, bytearray)): + if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: + await self.send_image(data, sid=sid) + elif isinstance(data, (bytes, bytearray)): await self.send_bytes(event, data, sid) else: await self.send_json(event, data, sid) @@ -512,6 +515,30 @@ class PromptServer(): message.extend(data) return message + async def send_image(self, image_data, sid=None): + image_type = image_data[0] + image = image_data[1] + max_size = image_data[2] + if max_size is not None: + if hasattr(Image, 'Resampling'): + resampling = Image.Resampling.BILINEAR + else: + resampling = Image.ANTIALIAS + + image = ImageOps.contain(image, (max_size, max_size), resampling) + type_num = 1 + if image_type == "JPEG": + type_num = 1 + elif image_type == "PNG": + type_num = 2 + + bytesIO = BytesIO() + header = struct.pack(">I", type_num) + bytesIO.write(header) + image.save(bytesIO, format=image_type, quality=95, compress_level=4) + preview_bytes = bytesIO.getvalue() + await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) + async def send_bytes(self, event, data, sid=None): message = self.encode_bytes(event, data) diff --git a/web/types/comfy.d.ts b/web/types/comfy.d.ts index 8444e13a8..f7129b555 100644 --- a/web/types/comfy.d.ts +++ b/web/types/comfy.d.ts @@ -30,9 +30,7 @@ export interface ComfyExtension { getCustomWidgets( app: ComfyApp ): Promise< - Array< - Record { widget?: IWidget; minWidth?: number; minHeight?: number }> - > + Record { widget?: IWidget; minWidth?: number; minHeight?: number }> >; /** * Allows the extension to add additional handling to the node before it is registered with LGraph