From 85a8900a148c881914ed16900108f08fd26981c1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 22 Jul 2023 11:05:33 -0400 Subject: [PATCH 1/4] Disable cuda malloc on regular GTX 960. --- cuda_malloc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_malloc.py b/cuda_malloc.py index 382432215..faee91a34 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -37,7 +37,7 @@ def get_gpu_names(): return set() def cuda_malloc_supported(): - blacklist = {"GeForce GTX 960M", "GeForce GTX 950M", "GeForce 945M", "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745"} + blacklist = {"GeForce GTX 960", "GeForce GTX 950", "GeForce 945M", "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745"} try: names = get_gpu_names() except: From 12a6e93171ce6d67001362f4cd25cc6d279f17ed Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 22 Jul 2023 11:25:49 -0400 Subject: [PATCH 2/4] Del the right object when applying lora. --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 7f7c06bc5..ddafa0b52 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -353,7 +353,7 @@ class ModelPatcher: temp_weight = weight.to(torch.float32, copy=True) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) set_attr(self.model, key, out_weight) - del temp_weight + del weight return self.model def calculate_weight(self, patches, weight, key): From 67be7eb81d59a7997e79a58e12e408e778d976ff Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 22 Jul 2023 17:01:12 -0400 Subject: [PATCH 3/4] Nodes can now patch the unet function. --- comfy/samplers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 50fda016d..9eee25a92 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -248,7 +248,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c['transformer_options'] = transformer_options - output = model_function(input_x, timestep_, **c).chunk(batch_chunks) + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + else: + output = model_function(input_x, timestep_, **c).chunk(batch_chunks) del input_x model_management.throw_exception_if_processing_interrupted() From 22f29d66cae4e65456c3bfcdaa16340c1fef12ab Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 22 Jul 2023 21:26:45 -0400 Subject: [PATCH 4/4] Try to fix memory issue with lora. --- comfy/model_management.py | 8 ++++++-- comfy/sd.py | 9 ++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index b1afeb715..241706925 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -281,19 +281,23 @@ def load_model_gpu(model): vram_set_state = VRAMState.LOW_VRAM real_model = model.model + patch_model_to = None if vram_set_state == VRAMState.DISABLED: pass elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: model_accelerated = False - real_model.to(torch_dev) + patch_model_to = torch_dev try: - real_model = model.patch_model() + real_model = model.patch_model(device_to=patch_model_to) except Exception as e: model.unpatch_model() unload_model() raise e + if patch_model_to is not None: + real_model.to(torch_dev) + if vram_set_state == VRAMState.NO_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) diff --git a/comfy/sd.py b/comfy/sd.py index ddafa0b52..1f364dd1f 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -338,7 +338,7 @@ class ModelPatcher: sd.pop(k) return sd - def patch_model(self): + def patch_model(self, device_to=None): model_sd = self.model_state_dict() for key in self.patches: if key not in model_sd: @@ -350,10 +350,13 @@ class ModelPatcher: if key not in self.backup: self.backup[key] = weight.to(self.offload_device) - temp_weight = weight.to(torch.float32, copy=True) + if device_to is not None: + temp_weight = weight.float().to(device_to, copy=True) + else: + temp_weight = weight.to(torch.float32, copy=True) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) set_attr(self.model, key, out_weight) - del weight + del temp_weight return self.model def calculate_weight(self, patches, weight, key):