diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 3479d5410..61d3532a1 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -256,6 +256,8 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal = model.model device = comfy.model_management.intermediate_device() torch_device = comfy.model_management.get_torch_device() + had_image_size = hasattr(model_internal, "image_size") + original_image_size = getattr(model_internal, "image_size", None) def prepare_tensor(pil_img, size): resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) @@ -263,15 +265,21 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device) return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device) - model_internal.image_size = 512 - input_512 = prepare_tensor(cropped_img_tensor, 512) - cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] - cond_1024 = None - if include_1024: - model_internal.image_size = 1024 - input_1024 = prepare_tensor(cropped_img_tensor, 1024) - cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] + try: + model_internal.image_size = 512 + input_512 = prepare_tensor(cropped_img_tensor, 512) + cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] + + if include_1024: + model_internal.image_size = 1024 + input_1024 = prepare_tensor(cropped_img_tensor, 1024) + cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] + finally: + if not had_image_size: + delattr(model_internal, "image_size") + else: + model_internal.image_size = original_image_size conditioning = { 'cond_512': cond_512.to(device), diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py new file mode 100644 index 000000000..920eca471 --- /dev/null +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -0,0 +1,127 @@ +import importlib +import sys +import types +import unittest +from unittest.mock import patch + +import torch +from PIL import Image + + +class _DummyPort: + @staticmethod + def Input(*args, **kwargs): + return None + + @staticmethod + def Output(*args, **kwargs): + return None + + +class _DummyIO: + ComfyNode = object + + @staticmethod + def Schema(*args, **kwargs): + return None + + @staticmethod + def NodeOutput(*args, **kwargs): + return args + + def __getattr__(self, name): + return _DummyPort + + +class _DummyTypes: + def __getattr__(self, name): + return lambda *args, **kwargs: None + + +dummy_comfy_api_latest = types.SimpleNamespace( + ComfyExtension=object, + IO=_DummyIO(), + Types=_DummyTypes(), +) + +dummy_sparse_tensor = type("SparseTensor", (), {}) +dummy_trellis_vae = types.SimpleNamespace(SparseTensor=dummy_sparse_tensor) + +with patch.dict(sys.modules, { + "comfy_api.latest": dummy_comfy_api_latest, + "comfy.ldm.trellis2.vae": dummy_trellis_vae, +}): + nodes_trellis2 = importlib.import_module("comfy_extras.nodes_trellis2") + + +class DummyInnerModel: + def __init__(self, image_size=..., fail_on_call=None): + self.call_count = 0 + self.fail_on_call = fail_on_call + if image_size is not ...: + self.image_size = image_size + + def __call__(self, input_tensor, skip_norm_elementwise=True): + self.call_count += 1 + if self.fail_on_call == self.call_count: + raise RuntimeError("expected conditioning failure") + return (torch.ones((1, 4), dtype=torch.float32),) + + +class DummyModel: + def __init__(self, inner_model): + self.model = inner_model + + +class TestRunConditioningRestore(unittest.TestCase): + def setUp(self): + self.intermediate_patch = patch.object( + nodes_trellis2.comfy.model_management, "intermediate_device", lambda: "cpu" + ) + self.torch_device_patch = patch.object( + nodes_trellis2.comfy.model_management, "get_torch_device", lambda: "cpu" + ) + self.intermediate_patch.start() + self.torch_device_patch.start() + + def tearDown(self): + self.intermediate_patch.stop() + self.torch_device_patch.stop() + + @staticmethod + def make_test_image(): + return Image.new("RGB", (8, 8), color="white") + + def test_restores_existing_image_size_after_success(self): + inner_model = DummyInnerModel(image_size=777) + + nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) + + self.assertEqual(inner_model.image_size, 777) + + def test_deletes_missing_image_size_after_success(self): + inner_model = DummyInnerModel() + + nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) + + self.assertFalse(hasattr(inner_model, "image_size")) + + def test_restores_existing_image_size_after_512_failure(self): + inner_model = DummyInnerModel(image_size=777, fail_on_call=1) + + with self.assertRaisesRegex(RuntimeError, "expected conditioning failure"): + nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) + + self.assertEqual(inner_model.image_size, 777) + + def test_deletes_missing_image_size_after_1024_failure(self): + inner_model = DummyInnerModel(fail_on_call=2) + + with self.assertRaisesRegex(RuntimeError, "expected conditioning failure"): + nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) + + self.assertFalse(hasattr(inner_model, "image_size")) + + +if __name__ == "__main__": + unittest.main()