diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 4bd075e9d..9dfd69977 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -88,6 +88,8 @@ class ControlBase: self.strength = strength self.timestep_percent_range = timestep_percent_range if self.latent_format is not None: + if vae is None: + logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.") self.vae = vae self.extra_concat_orig = extra_concat.copy() if self.concat_mask and len(self.extra_concat_orig) == 0: @@ -222,6 +224,9 @@ class ControlNet(ControlBase): compression_ratio = self.compression_ratio if self.vae is not None: compression_ratio *= self.vae.downscale_ratio + else: + if self.latent_format is not None: + raise ValueError("This Controlnet needs a VAE but none was provided, please use a different ControlNetApply node with a VAE input.") self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") if self.vae is not None: loaded_models = comfy.model_management.loaded_models(only_currently_used=True) @@ -395,19 +400,22 @@ class ControlLora(ControlNet): def controlnet_config(sd, model_options={}): model_config = comfy.model_detection.model_config_from_unet(sd, "", True) - supported_inference_dtypes = model_config.supported_inference_dtypes + unet_dtype = model_options.get("dtype", None) + if unet_dtype is None: + weight_dtype = comfy.utils.weight_dtype(sd) + + supported_inference_dtypes = list(model_config.supported_inference_dtypes) + if weight_dtype is not None: + supported_inference_dtypes.append(weight_dtype) + + unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes) - controlnet_config = model_config.unet_config - unet_dtype = model_options.get("dtype", comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)) load_device = comfy.model_management.get_torch_device() manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) operations = model_options.get("custom_operations", None) if operations is None: - if manual_cast_dtype is not None: - operations = comfy.ops.manual_cast - else: - operations = comfy.ops.disable_weight_init + operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True) offload_device = comfy.model_management.unet_offload_device() return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device @@ -490,8 +498,8 @@ def convert_mistoline(sd): return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) -def load_controlnet(ckpt_path, model=None, model_options={}): - controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) +def load_controlnet_state_dict(state_dict, model=None, model_options={}): + controlnet_data = state_dict if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT return load_controlnet_hunyuandit(controlnet_data, model_options=model_options) @@ -573,27 +581,35 @@ def load_controlnet(ckpt_path, model=None, model_options={}): else: net = load_t2i_adapter(controlnet_data, model_options=model_options) if net is None: - logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path)) + logging.error("error could not detect control model type.") return net if controlnet_config is None: model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True) - supported_inference_dtypes = model_config.supported_inference_dtypes + supported_inference_dtypes = list(model_config.supported_inference_dtypes) controlnet_config = model_config.unet_config + unet_dtype = model_options.get("dtype", None) + if unet_dtype is None: + weight_dtype = comfy.utils.weight_dtype(controlnet_data) + + if supported_inference_dtypes is None: + supported_inference_dtypes = [comfy.model_management.unet_dtype()] + + if weight_dtype is not None: + supported_inference_dtypes.append(weight_dtype) + + unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes) + load_device = comfy.model_management.get_torch_device() - if supported_inference_dtypes is None: - unet_dtype = comfy.model_management.unet_dtype() - else: - unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) - if manual_cast_dtype is not None: - controlnet_config["operations"] = comfy.ops.manual_cast - if "custom_operations" in model_options: - controlnet_config["operations"] = model_options["custom_operations"] - if "dtype" in model_options: - controlnet_config["dtype"] = model_options["dtype"] + operations = model_options.get("custom_operations", None) + if operations is None: + operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype) + + controlnet_config["operations"] = operations + controlnet_config["dtype"] = unet_dtype controlnet_config["device"] = comfy.model_management.unet_offload_device() controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] @@ -628,14 +644,21 @@ def load_controlnet(ckpt_path, model=None, model_options={}): if len(unexpected) > 0: logging.debug("unexpected controlnet keys: {}".format(unexpected)) - global_average_pooling = False - filename = os.path.splitext(ckpt_path)[0] - if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling - global_average_pooling = True - + global_average_pooling = model_options.get("global_average_pooling", False) control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control +def load_controlnet(ckpt_path, model=None, model_options={}): + if "global_average_pooling" not in model_options: + filename = os.path.splitext(ckpt_path)[0] + if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling + model_options["global_average_pooling"] = True + + cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options) + if cnet is None: + logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path)) + return cnet + class T2IAdapter(ControlBase): def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None): super().__init__(device) diff --git a/comfy/model_management.py b/comfy/model_management.py index c15df952d..a43523ada 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -640,6 +640,8 @@ def maximum_vram_for_weights(device=None): return (get_total_memory(device) * 0.88 - minimum_inference_memory()) def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): + if model_params < 0: + model_params = 1000000000000000000000 if args.bf16_unet: return torch.bfloat16 if args.fp16_unet: diff --git a/comfy/ops.py b/comfy/ops.py index 43ed55adb..1b386dba7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -300,10 +300,10 @@ class fp8_ops(manual_cast): return torch.nn.functional.linear(input, weight, bias) -def pick_operations(weight_dtype, compute_dtype, load_device=None): +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False): if compute_dtype is None or weight_dtype == compute_dtype: return disable_weight_init - if args.fast: + if args.fast and not disable_fast_fp8: if comfy.model_management.supports_fp8_compute(load_device): return fp8_ops return manual_cast diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index c4bccaa6f..77b8fb8e2 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -103,5 +103,5 @@ NODE_CLASS_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = { # Sampling - "ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT", + "ControlNetApplySD3": "Apply Controlnet", } diff --git a/nodes.py b/nodes.py index 292ff9cfa..c4595f6a1 100644 --- a/nodes.py +++ b/nodes.py @@ -1917,8 +1917,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ConditioningSetArea": "Conditioning (Set Area)", "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)", "ConditioningSetMask": "Conditioning (Set Mask)", - "ControlNetApply": "Apply ControlNet", - "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)", + "ControlNetApply": "Apply ControlNet (OLD)", + "ControlNetApplyAdvanced": "Apply ControlNet (OLD Advanced)", # Latent "VAEEncodeForInpaint": "VAE Encode (for Inpainting)", "SetLatentNoiseMask": "Set Latent Noise Mask",