From cd53b3404ccef4b9fa4d40c389fb28a35603ddbc Mon Sep 17 00:00:00 2001
From: Benjamin Berman
Date: Tue, 15 Aug 2023 18:50:52 -0700
Subject: [PATCH] wip deep_floyd nodes

---
 .gitignore                                 |  1 +
 comfy_extras/nodes/deepfloyd/deep_floyd.py | 31 ++++++++++++++--------
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/.gitignore b/.gitignore
index c42639937..a46c3671c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@
 ![Cc]ustom_nodes/__init__.py
 !/custom_nodes/example_node.py.example
 **/put*here
+![Mm]odels/deepfloyd/put_deepfloyd_repos_here
 /extra_model_paths.yaml
 /.vs
 .idea/
diff --git a/comfy_extras/nodes/deepfloyd/deep_floyd.py b/comfy_extras/nodes/deepfloyd/deep_floyd.py
index 529e19173..507a94757 100644
--- a/comfy_extras/nodes/deepfloyd/deep_floyd.py
+++ b/comfy_extras/nodes/deepfloyd/deep_floyd.py
@@ -7,7 +7,7 @@ import torch
 import torchvision.transforms.functional as TF
 from diffusers import DiffusionPipeline, IFPipeline, StableDiffusionUpscalePipeline, IFSuperResolutionPipeline
 from diffusers.utils import is_accelerate_available, is_accelerate_version
-from transformers import T5EncoderModel
+from transformers import T5EncoderModel, BitsAndBytesConfig
 from comfy.model_management import throw_exception_if_processing_interrupted, get_torch_device, cpu_state, CPUState
 
 # todo: this relies on the setup-py cleanup fork
@@ -89,8 +89,7 @@ class Loader:
         return {
             "required": {
                 "model_name": (Loader._MODELS, {"default": "I-M"}),
-                "load_in_8bit": ([False, True], {"default": False}),
-                "device": ("STRING", {"default": ""}),
+                "quantization": (list(Loader._QUANTIZATIONS.keys()), {"default": "16-bit"}),
             }
         }
 
@@ -100,8 +99,19 @@ class Loader:
 
     _MODELS = ["I-M", "I-L", "I-XL", "II-M", "II-L", "III", "t5"]
 
+    _QUANTIZATIONS = {
+        "4-bit": BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+        ),
+        "8-bit": BitsAndBytesConfig(
+            load_in_8bit=True,
+        ),
+        "16-bit": None,
+    }
+
     # todo: correctly use load_in_8bit
-    def process(self, model_name: str, load_in_8bit: bool, device: str):
+    def process(self, model_name: str, quantization: str):
         assert model_name in Loader._MODELS
 
         model_v: DiffusionPipeline
@@ -113,18 +123,17 @@ class Loader:
             "feature_extractor": None,
             "safety_checker": None,
             "watermarker": None,
-            "load_in_8bit": load_in_8bit,
-            # todo: fix diffusers when using device_map auto on multi-gpu setups, layers are not assigned to different devices correctly
-            "device_map": None if device else "auto"
+            "device_map": None
         }
 
+        if Loader._QUANTIZATIONS[quantization] is not None:
+            kwargs['quantization_config'] = Loader._QUANTIZATIONS[quantization]
+
         if model_name == "t5":
             # find any valid IF model
             model_path = next(os.path.dirname(file) for file in
                               _find_files(_model_base_path, "model_index.json")
                               if any(x == T5EncoderModel.__name__ for x in json.load(open(file, 'r'))["text_encoder"]))
-            # todo: this must use load_in_8bit correctly
-            # kwargs["text_encoder"] = text_encoder
             kwargs["unet"] = None
         elif model_name == "III":
             model_path = f"{_model_base_path}/stable-diffusion-x4-upscaler"
@@ -138,8 +147,8 @@
             **kwargs
        )
 
-        if device:
-            model_v = model_v.to(device)
+        device = get_torch_device()
+        model_v = model_v.to(device)
 
         _cpu_offload(model_v, gpu_id=model_v.device.index)
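
Note for reviewers: the hunks above route a BitsAndBytesConfig through the pipeline kwargs. As a sanity check, here is a minimal standalone sketch of the same idea, following the route documented for DeepFloyd IF in diffusers: quantize only the T5 text encoder, then hand it to the pipeline. The hub id DeepFloyd/IF-I-M-v1.0, the fp16 dtype, and the example prompt are illustrative assumptions; the node itself resolves a local repo under models/deepfloyd instead of downloading from the hub.

import torch
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel, BitsAndBytesConfig

# same settings as the "4-bit" entry in Loader._QUANTIZATIONS
four_bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

# quantize only the text encoder; assumption: loading from the hub,
# whereas the node resolves a local repo path instead
text_encoder = T5EncoderModel.from_pretrained(
    "DeepFloyd/IF-I-M-v1.0",
    subfolder="text_encoder",
    quantization_config=four_bit,
)

# mirror the model_name == "t5" branch: text encoder only, unet disabled
pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-M-v1.0",
    text_encoder=text_encoder,
    unet=None,
    torch_dtype=torch.float16,
)
prompt_embeds, negative_embeds = pipe.encode_prompt("a photo of a corgi")

On the "4-bit" entry: bnb_4bit_use_double_quant quantizes the quantization constants themselves, which shaves a little more memory off than plain 4-bit loading.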