diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 00d60ff72..b680de058 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1361,13 +1361,18 @@ class ModelPatcher: for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): callback(self) - def prepare_state(self, timestep, model_options, ignore_multigpu=False): + def prepare_state(self, timestep, model_options): + ignore_multigpu = model_options.get("ignore_multigpu", False) for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): - callback(self, timestep, model_options, ignore_multigpu) + callback(self, timestep, model_options) if not ignore_multigpu and "multigpu_clones" in model_options: - for p in model_options["multigpu_clones"].values(): - p: ModelPatcher - p.prepare_state(timestep, model_options, ignore_multigpu=True) + model_options["ignore_multigpu"] = True + try: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p.prepare_state(timestep, model_options) + finally: + model_options.pop("ignore_multigpu", None) def restore_hook_patches(self): if self.hook_patches_backup is not None: