Restore prepare_state backward-compatible signature

Drop the new ignore_multigpu positional argument from prepare_state and
from the ON_PREPARE_STATE callbacks; pass the flag via model_options
instead. This restores the original 3-arg callback signature so existing
custom-node ON_PREPARE_STATE handlers keep working unchanged, while
still letting prepare_state's recursive call into multigpu_clones
short-circuit.

Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Kosinkadink 2026-05-21 11:35:39 -07:00
parent 4d9106dced
commit adde1239b1

View File

@ -1361,13 +1361,18 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self) callback(self)
def prepare_state(self, timestep, model_options, ignore_multigpu=False): def prepare_state(self, timestep, model_options):
ignore_multigpu = model_options.get("ignore_multigpu", False)
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep, model_options, ignore_multigpu) callback(self, timestep, model_options)
if not ignore_multigpu and "multigpu_clones" in model_options: if not ignore_multigpu and "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values(): model_options["ignore_multigpu"] = True
p: ModelPatcher try:
p.prepare_state(timestep, model_options, ignore_multigpu=True) for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p.prepare_state(timestep, model_options)
finally:
model_options.pop("ignore_multigpu", None)
def restore_hook_patches(self): def restore_hook_patches(self):
if self.hook_patches_backup is not None: if self.hook_patches_backup is not None: