diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 46dcf5be8..24e3e5fcd 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -293,6 +293,9 @@ class ModelPatcher:
     def lowvram_patch_counter(self):
         return self.model.lowvram_patch_counter
 
+    def get_free_memory(self, device):
+        return comfy.model_management.get_free_memory(device)
+
     def clone(self):
         n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
         n.patches = {}
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 1989ef107..d495ca203 100755
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -260,7 +260,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
         to_batch_temp.reverse()
         to_batch = to_batch_temp[:1]
 
-        free_memory = model_management.get_free_memory(x_in.device)
+        free_memory = model.current_patcher.get_free_memory(x_in.device)
         for i in range(1, len(to_batch_temp) + 1):
             batch_amount = to_batch_temp[:len(to_batch_temp)//i]
             input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
diff --git a/comfy/sd.py b/comfy/sd.py
index b71157343..f83164d23 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -797,7 +797,7 @@ class VAE:
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
-            free_memory = model_management.get_free_memory(self.device)
+            free_memory = self.patcher.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
 
@@ -871,7 +871,7 @@ class VAE:
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
-            free_memory = model_management.get_free_memory(self.device)
+            free_memory = self.patcher.get_free_memory(self.device)
             batch_number = int(free_memory / max(1, memory_used))
             batch_number = max(1, batch_number)
             samples = None
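
For context, a minimal sketch of what this indirection enables (not part of the diff): because samplers.py and sd.py now ask the patcher for free memory instead of calling model_management.get_free_memory() directly, a ModelPatcher subclass can override get_free_memory() to shape the value the batching heuristics see. The subclass name and cap_bytes attribute below are hypothetical, for illustration only:

import comfy.model_management
from comfy.model_patcher import ModelPatcher

class CappedMemoryPatcher(ModelPatcher):
    # Hypothetical subclass: never report more than cap_bytes of free
    # memory, so the batch-size estimates in _calc_cond_batch() and
    # VAE.decode()/encode() stay conservative (e.g. on shared GPUs).
    cap_bytes = 4 * 1024**3  # 4 GiB cap, chosen arbitrarily for the sketch

    def get_free_memory(self, device):
        # Defer to the global query, then clamp the result.
        free = comfy.model_management.get_free_memory(device)
        return min(free, self.cap_bytes)

Using a class attribute rather than an __init__ parameter keeps the sketch compatible with ModelPatcher.clone(), which reconstructs instances via self.__class__ with the base constructor's arguments.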