diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 087b0fbfa..2bb363fab 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1692,16 +1692,27 @@ class ModelPatcherDynamic(ModelPatcher): self.model.dynamic_vbars = {} if not hasattr(self.model, "dynamic_pins"): self.model.dynamic_pins = {} - if self.load_device not in self.model.dynamic_pins: - self.model.dynamic_pins[self.load_device] = { + self.register_load_device(self.load_device) + self.non_dynamic_delegate_model = None + assert load_device is not None + + def register_load_device(self, device): + """Ensure dynamic_pins has an entry for *device*. + + Called from __init__ and also from any code that retargets an + already-constructed patcher to a new load_device (e.g. the + Select{Model,CLIP,VAE}Device selector nodes); without this entry + partially_unload_ram() raises KeyError when it tries to read the + per-device pin state. + """ + if device not in self.model.dynamic_pins: + self.model.dynamic_pins[device] = { "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), "hostbufs_initialized": False, "failed": False, "active": False, } - self.non_dynamic_delegate_model = None - assert load_device is not None def is_dynamic(self): return True diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 9e03c56f0..df701af56 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -90,6 +90,8 @@ class SelectModelDeviceNode(io.ComfyNode): model.load_device = resolved if resolved.type == "cpu": model.offload_device = resolved + if hasattr(model, "register_load_device"): + model.register_load_device(resolved) return io.NodeOutput(model) @@ -135,6 +137,8 @@ class SelectCLIPDeviceNode(io.ComfyNode): clip.patcher.load_device = resolved if resolved.type == "cpu": clip.patcher.offload_device = resolved + if hasattr(clip.patcher, "register_load_device"): + clip.patcher.register_load_device(resolved) return io.NodeOutput(clip) @@ -185,6 +189,8 @@ class SelectVAEDeviceNode(io.ComfyNode): vae.device = resolved vae.patcher.load_device = resolved vae.patcher.offload_device = comfy.model_management.vae_offload_device() + if hasattr(vae.patcher, "register_load_device"): + vae.patcher.register_load_device(resolved) return io.NodeOutput(vae)