Mirror of https://github.com/comfyanonymous/ComfyUI.git
wip deep_floyd nodes
commit cd53b3404c
parent 8ecb5c11e9
.gitignore (vendored) | 1 +
@@ -8,6 +8,7 @@
 ![Cc]ustom_nodes/__init__.py
 !/custom_nodes/example_node.py.example
 **/put*here
+![Mm]odels/deepfloyd/put_deepfloyd_repos_here
 /extra_model_paths.yaml
 /.vs
 .idea/
@@ -7,7 +7,7 @@ import torch
 import torchvision.transforms.functional as TF
 from diffusers import DiffusionPipeline, IFPipeline, StableDiffusionUpscalePipeline, IFSuperResolutionPipeline
 from diffusers.utils import is_accelerate_available, is_accelerate_version
-from transformers import T5EncoderModel
+from transformers import T5EncoderModel, BitsAndBytesConfig
 
 from comfy.model_management import throw_exception_if_processing_interrupted, get_torch_device, cpu_state, CPUState
 # todo: this relies on the setup-py cleanup fork
@@ -89,8 +89,7 @@ class Loader:
         return {
             "required": {
                 "model_name": (Loader._MODELS, {"default": "I-M"}),
-                "load_in_8bit": ([False, True], {"default": False}),
-                "device": ("STRING", {"default": ""}),
+                "quantization": (list(Loader._QUANTIZATIONS.keys()), {"default": "16-bit"}),
             }
         }
 
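For readers unfamiliar with ComfyUI's node API: any input declared as a (choices, options) pair is rendered as a dropdown widget, which is how the new "quantization" input exposes the _QUANTIZATIONS keys in the UI. A minimal self-contained sketch of that pattern follows; the node name and category are illustrative, not part of this commit:

class ExampleQuantPicker:
    """Illustrative node, not from this commit."""

    _CHOICES = ["4-bit", "8-bit", "16-bit"]  # mirrors Loader._QUANTIZATIONS.keys()

    @classmethod
    def INPUT_TYPES(cls):
        # A (choices, options) pair renders as a combo box in the ComfyUI frontend.
        return {
            "required": {
                "quantization": (cls._CHOICES, {"default": "16-bit"}),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "pick"
    CATEGORY = "examples"

    def pick(self, quantization):
        return (quantization,)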
@@ -100,8 +99,19 @@ class Loader:
 
     _MODELS = ["I-M", "I-L", "I-XL", "II-M", "II-L", "III", "t5"]
 
+    _QUANTIZATIONS = {
+        "4-bit": BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+        ),
+        "8-bit": BitsAndBytesConfig(
+            load_in_8bit=True,
+        ),
+        "16-bit": None,
+    }
+
     # todo: correctly use load_in_8bit
-    def process(self, model_name: str, load_in_8bit: bool, device: str):
+    def process(self, model_name: str, quantization: str):
         assert model_name in Loader._MODELS
 
         model_v: DiffusionPipeline
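BitsAndBytesConfig is transformers' container for bitsandbytes quantization settings, so this table maps each UI label to a ready-made config, with None standing for a plain unquantized load. A hedged sketch of how such a config is consumed when loading the T5 encoder on its own; the repo id is illustrative, and bitsandbytes plus a CUDA device are assumed:

from transformers import BitsAndBytesConfig, T5EncoderModel

quant = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,  # quantize the quantization constants as well
)

# quantization_config is only passed when a bitsandbytes mode is selected;
# the "16-bit" entry omits it entirely, as the node above does.
text_encoder = T5EncoderModel.from_pretrained(
    "DeepFloyd/IF-I-M-v1.0",   # illustrative repo id
    subfolder="text_encoder",
    quantization_config=quant,
    device_map="auto",          # bitsandbytes weights load straight onto the GPU
)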
@@ -113,18 +123,17 @@ class Loader:
             "feature_extractor": None,
             "safety_checker": None,
             "watermarker": None,
-            "load_in_8bit": load_in_8bit,
-            # todo: fix diffusers when using device_map auto on multi-gpu setups, layers are not assigned to different devices correctly
-            "device_map": None if device else "auto"
+            "device_map": None
         }
 
+        if Loader._QUANTIZATIONS[quantization] is not None:
+            kwargs['quantization_config'] = Loader._QUANTIZATIONS[quantization]
+
         if model_name == "t5":
             # find any valid IF model
             model_path = next(os.path.dirname(file) for file in _find_files(_model_base_path, "model_index.json") if
                               any(x == T5EncoderModel.__name__ for x in
                                   json.load(open(file, 'r'))["text_encoder"]))
-            # todo: this must use load_in_8bit correctly
-            # kwargs["text_encoder"] = text_encoder
             kwargs["unet"] = None
         elif model_name == "III":
             model_path = f"{_model_base_path}/stable-diffusion-x4-upscaler"
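_find_files is a module-local helper whose body falls outside this diff. Assuming it recursively yields paths whose basename matches a filename, the t5 branch reduces to the sketch below; find_if_model is a hypothetical name used only for illustration:

import json
import os

def _find_files(base_path, filename):
    # Assumed shape of the helper: walk the tree and yield every match.
    for root, _dirs, files in os.walk(base_path):
        for f in files:
            if f == filename:
                yield os.path.join(root, f)

def find_if_model(base_path):
    # A diffusers model_index.json lists each pipeline component as
    # ["library_name", "ClassName"], so "T5EncoderModel" appearing in the
    # text_encoder entry identifies a DeepFloyd IF checkpoint.
    for file in _find_files(base_path, "model_index.json"):
        with open(file, "r") as fh:
            index = json.load(fh)
        if "T5EncoderModel" in index.get("text_encoder", []):
            return os.path.dirname(file)
    return None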
@@ -138,8 +147,8 @@ class Loader:
             **kwargs
         )
 
-        if device:
-            model_v = model_v.to(device)
+        device = get_torch_device()
+        model_v = model_v.to(device)
 
         _cpu_offload(model_v, gpu_id=model_v.device.index)
 
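_cpu_offload is likewise a local helper not shown in this diff. Given the is_accelerate_available import earlier in the commit, one plausible reading is a thin guard around diffusers' built-in model offloading; this is an assumption, not the commit's actual implementation:

from diffusers import DiffusionPipeline
from diffusers.utils import is_accelerate_available

def _cpu_offload(pipe: DiffusionPipeline, gpu_id=None):
    # Assumed behavior: keep submodules in system RAM and move each one to
    # the GPU only while it runs. Requires the accelerate package.
    if not is_accelerate_available():
        return
    pipe.enable_model_cpu_offload(gpu_id=gpu_id)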