From a355f38ecc508c49f6a4f592eb67eeee9527e8a7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 21 Sep 2024 01:32:46 -0400 Subject: [PATCH 1/4] Make the SD3 controlnet node the default one. --- comfy_extras/nodes_sd3.py | 2 +- nodes.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index c4bccaa6f..77b8fb8e2 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -103,5 +103,5 @@ NODE_CLASS_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = { # Sampling - "ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT", + "ControlNetApplySD3": "Apply Controlnet", } diff --git a/nodes.py b/nodes.py index 292ff9cfa..c4595f6a1 100644 --- a/nodes.py +++ b/nodes.py @@ -1917,8 +1917,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ConditioningSetArea": "Conditioning (Set Area)", "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)", "ConditioningSetMask": "Conditioning (Set Mask)", - "ControlNetApply": "Apply ControlNet", - "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)", + "ControlNetApply": "Apply ControlNet (OLD)", + "ControlNetApplyAdvanced": "Apply ControlNet (OLD Advanced)", # Latent "VAEEncodeForInpaint": "VAE Encode (for Inpainting)", "SetLatentNoiseMask": "Set Latent Noise Mask", From 9f7e9f05478cc51c5f3de38a969756c20381cd08 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 21 Sep 2024 01:33:18 -0400 Subject: [PATCH 2/4] Add an error message when a controlnet needs a VAE but none is given. --- comfy/controlnet.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 4bd075e9d..61a67f3f4 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -88,6 +88,8 @@ class ControlBase: self.strength = strength self.timestep_percent_range = timestep_percent_range if self.latent_format is not None: + if vae is None: + logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.") self.vae = vae self.extra_concat_orig = extra_concat.copy() if self.concat_mask and len(self.extra_concat_orig) == 0: @@ -222,6 +224,9 @@ class ControlNet(ControlBase): compression_ratio = self.compression_ratio if self.vae is not None: compression_ratio *= self.vae.downscale_ratio + else: + if self.latent_format is not None: + raise ValueError("This Controlnet needs a VAE but none was provided, please use a different ControlNetApply node with a VAE input.") self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") if self.vae is not None: loaded_models = comfy.model_management.loaded_models(only_currently_used=True) From 2d810b081e3e992105a58b428a70cdd70779c85a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 21 Sep 2024 01:51:51 -0400 Subject: [PATCH 3/4] Add load_controlnet_state_dict function. --- comfy/controlnet.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 61a67f3f4..ff4385b33 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -495,8 +495,8 @@ def convert_mistoline(sd): return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) -def load_controlnet(ckpt_path, model=None, model_options={}): - controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) +def load_controlnet_state_dict(state_dict, model=None, model_options={}): + controlnet_data = state_dict if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT return load_controlnet_hunyuandit(controlnet_data, model_options=model_options) @@ -578,7 +578,7 @@ def load_controlnet(ckpt_path, model=None, model_options={}): else: net = load_t2i_adapter(controlnet_data, model_options=model_options) if net is None: - logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path)) + logging.error("error could not detect control model type.") return net if controlnet_config is None: @@ -633,14 +633,21 @@ def load_controlnet(ckpt_path, model=None, model_options={}): if len(unexpected) > 0: logging.debug("unexpected controlnet keys: {}".format(unexpected)) - global_average_pooling = False - filename = os.path.splitext(ckpt_path)[0] - if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling - global_average_pooling = True - + global_average_pooling = model_options.get("global_average_pooling", False) control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control +def load_controlnet(ckpt_path, model=None, model_options={}): + if "global_average_pooling" not in model_options: + filename = os.path.splitext(ckpt_path)[0] + if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling + model_options["global_average_pooling"] = True + + cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options) + if cnet is None: + logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path)) + return cnet + class T2IAdapter(ControlBase): def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None): super().__init__(device) From dc96a1ae19b1d714a791f1fcb21578389955bbfd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 21 Sep 2024 04:50:12 -0400 Subject: [PATCH 4/4] Load controlnet in fp8 if weights are in fp8. --- comfy/controlnet.py | 47 ++++++++++++++++++++++++--------------- comfy/model_management.py | 2 ++ comfy/ops.py | 4 ++-- 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ff4385b33..9dfd69977 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -400,19 +400,22 @@ class ControlLora(ControlNet): def controlnet_config(sd, model_options={}): model_config = comfy.model_detection.model_config_from_unet(sd, "", True) - supported_inference_dtypes = model_config.supported_inference_dtypes + unet_dtype = model_options.get("dtype", None) + if unet_dtype is None: + weight_dtype = comfy.utils.weight_dtype(sd) + + supported_inference_dtypes = list(model_config.supported_inference_dtypes) + if weight_dtype is not None: + supported_inference_dtypes.append(weight_dtype) + + unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes) - controlnet_config = model_config.unet_config - unet_dtype = model_options.get("dtype", comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)) load_device = comfy.model_management.get_torch_device() manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) operations = model_options.get("custom_operations", None) if operations is None: - if manual_cast_dtype is not None: - operations = comfy.ops.manual_cast - else: - operations = comfy.ops.disable_weight_init + operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True) offload_device = comfy.model_management.unet_offload_device() return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device @@ -583,22 +586,30 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): if controlnet_config is None: model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True) - supported_inference_dtypes = model_config.supported_inference_dtypes + supported_inference_dtypes = list(model_config.supported_inference_dtypes) controlnet_config = model_config.unet_config + unet_dtype = model_options.get("dtype", None) + if unet_dtype is None: + weight_dtype = comfy.utils.weight_dtype(controlnet_data) + + if supported_inference_dtypes is None: + supported_inference_dtypes = [comfy.model_management.unet_dtype()] + + if weight_dtype is not None: + supported_inference_dtypes.append(weight_dtype) + + unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes) + load_device = comfy.model_management.get_torch_device() - if supported_inference_dtypes is None: - unet_dtype = comfy.model_management.unet_dtype() - else: - unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) - if manual_cast_dtype is not None: - controlnet_config["operations"] = comfy.ops.manual_cast - if "custom_operations" in model_options: - controlnet_config["operations"] = model_options["custom_operations"] - if "dtype" in model_options: - controlnet_config["dtype"] = model_options["dtype"] + operations = model_options.get("custom_operations", None) + if operations is None: + operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype) + + controlnet_config["operations"] = operations + controlnet_config["dtype"] = unet_dtype controlnet_config["device"] = comfy.model_management.unet_offload_device() controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] diff --git a/comfy/model_management.py b/comfy/model_management.py index 22a584a2e..a97d489d5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -626,6 +626,8 @@ def maximum_vram_for_weights(device=None): return (get_total_memory(device) * 0.88 - minimum_inference_memory()) def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): + if model_params < 0: + model_params = 1000000000000000000000 if args.bf16_unet: return torch.bfloat16 if args.fp16_unet: diff --git a/comfy/ops.py b/comfy/ops.py index 43ed55adb..1b386dba7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -300,10 +300,10 @@ class fp8_ops(manual_cast): return torch.nn.functional.linear(input, weight, bias) -def pick_operations(weight_dtype, compute_dtype, load_device=None): +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False): if compute_dtype is None or weight_dtype == compute_dtype: return disable_weight_init - if args.fast: + if args.fast and not disable_fast_fp8: if comfy.model_management.supports_fp8_compute(load_device): return fp8_ops return manual_cast