From b1bcaaf8fe299fdda181c3d551b12f7566227204 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 12 Apr 2026 21:08:35 -0500 Subject: [PATCH] fix(isolation): expose inner model state_dict in isolation --- comfy/isolation/model_patcher_proxy.py | 2 + .../isolation/test_inner_model_state_dict.py | 45 +++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 tests/isolation/test_inner_model_state_dict.py diff --git a/comfy/isolation/model_patcher_proxy.py b/comfy/isolation/model_patcher_proxy.py index f44de1d5a..a78c19acd 100644 --- a/comfy/isolation/model_patcher_proxy.py +++ b/comfy/isolation/model_patcher_proxy.py @@ -885,4 +885,6 @@ class _InnerModelProxy: return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k) if name == "diffusion_model": return self._parent._call_rpc("get_inner_model_attr", "diffusion_model") + if name == "state_dict": + return lambda: self._parent.model_state_dict() raise AttributeError(f"'{name}' not supported on isolated InnerModel") diff --git a/tests/isolation/test_inner_model_state_dict.py b/tests/isolation/test_inner_model_state_dict.py new file mode 100644 index 000000000..c170118ac --- /dev/null +++ b/tests/isolation/test_inner_model_state_dict.py @@ -0,0 +1,45 @@ +"""Test that _InnerModelProxy exposes state_dict for LoRA loading.""" +import sys +from pathlib import Path +from unittest.mock import MagicMock + +repo_root = Path(__file__).resolve().parents[2] +pyisolate_root = repo_root.parent / "pyisolate" +if pyisolate_root.exists(): + sys.path.insert(0, str(pyisolate_root)) + +from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + +def test_inner_model_proxy_state_dict_returns_keys(): + """_InnerModelProxy.state_dict() delegates to parent.model_state_dict().""" + proxy = object.__new__(ModelPatcherProxy) + proxy._model_id = "test_model" + proxy._rpc = MagicMock() + proxy._model_type_name = "SDXL" + proxy._inner_model_channels = None + + fake_keys = ["diffusion_model.input.weight", "diffusion_model.output.weight"] + proxy._call_rpc = MagicMock(return_value=fake_keys) + + inner = proxy.model + sd = inner.state_dict() + + assert isinstance(sd, dict) + assert "diffusion_model.input.weight" in sd + assert "diffusion_model.output.weight" in sd + proxy._call_rpc.assert_called_with("model_state_dict", None) + + +def test_inner_model_proxy_state_dict_callable(): + """state_dict is a callable, not a property — matches torch.nn.Module interface.""" + proxy = object.__new__(ModelPatcherProxy) + proxy._model_id = "test_model" + proxy._rpc = MagicMock() + proxy._model_type_name = "SDXL" + proxy._inner_model_channels = None + + proxy._call_rpc = MagicMock(return_value=[]) + + inner = proxy.model + assert callable(inner.state_dict)