From 04099ef605c373ec82081f3137a23bd1926b67ae Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 19:53:30 -0500 Subject: [PATCH 1/5] Restore Trellis2 clip vision image_size state --- comfy_extras/nodes_trellis2.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 3479d5410..b1ad5d1e1 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -256,6 +256,7 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal = model.model device = comfy.model_management.intermediate_device() torch_device = comfy.model_management.get_torch_device() + original_image_size = getattr(model_internal, "image_size", None) def prepare_tensor(pil_img, size): resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) @@ -268,10 +269,13 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] cond_1024 = None - if include_1024: - model_internal.image_size = 1024 - input_1024 = prepare_tensor(cropped_img_tensor, 1024) - cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] + try: + if include_1024: + model_internal.image_size = 1024 + input_1024 = prepare_tensor(cropped_img_tensor, 1024) + cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] + finally: + model_internal.image_size = original_image_size conditioning = { 'cond_512': cond_512.to(device), From d7416e56906b9bc8280223fd22532364428fc716 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 20:49:42 -0500 Subject: [PATCH 2/5] Guard full Trellis2 conditioning image_size restore --- comfy_extras/nodes_trellis2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index b1ad5d1e1..c8ac9bc33 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -264,12 +264,12 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device) return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device) - model_internal.image_size = 512 - input_512 = prepare_tensor(cropped_img_tensor, 512) - cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] - cond_1024 = None try: + model_internal.image_size = 512 + input_512 = prepare_tensor(cropped_img_tensor, 512) + cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] + if include_1024: model_internal.image_size = 1024 input_1024 = prepare_tensor(cropped_img_tensor, 1024) From 2ad1ca5531b96ad61cc4c80d81118d219b635afc Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 20:51:22 -0500 Subject: [PATCH 3/5] Handle missing Trellis2 image_size restore state --- comfy_extras/nodes_trellis2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index c8ac9bc33..2b712d113 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -256,7 +256,8 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal = model.model device = comfy.model_management.intermediate_device() torch_device = comfy.model_management.get_torch_device() - original_image_size = getattr(model_internal, "image_size", None) + image_size_missing = object() + original_image_size = getattr(model_internal, "image_size", image_size_missing) def prepare_tensor(pil_img, size): resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) @@ -275,7 +276,10 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): input_1024 = prepare_tensor(cropped_img_tensor, 1024) cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] finally: - model_internal.image_size = original_image_size + if original_image_size is image_size_missing: + delattr(model_internal, "image_size") + else: + model_internal.image_size = original_image_size conditioning = { 'cond_512': cond_512.to(device), From 7c6b237fe89d074510da0cd4f382a07829649547 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 21:05:07 -0500 Subject: [PATCH 4/5] Match Copilot image_size restore pattern --- comfy_extras/nodes_trellis2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 2b712d113..61d3532a1 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -256,8 +256,8 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal = model.model device = comfy.model_management.intermediate_device() torch_device = comfy.model_management.get_torch_device() - image_size_missing = object() - original_image_size = getattr(model_internal, "image_size", image_size_missing) + had_image_size = hasattr(model_internal, "image_size") + original_image_size = getattr(model_internal, "image_size", None) def prepare_tensor(pil_img, size): resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) @@ -276,7 +276,7 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): input_1024 = prepare_tensor(cropped_img_tensor, 1024) cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] finally: - if original_image_size is image_size_missing: + if not had_image_size: delattr(model_internal, "image_size") else: model_internal.image_size = original_image_size From cf3cfec5964afd91dd8404f2fb8ac7312ad458fa Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 21:11:58 -0500 Subject: [PATCH 5/5] Add Trellis2 image_size restore tests --- .../comfy_extras_test/nodes_trellis2_test.py | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 tests-unit/comfy_extras_test/nodes_trellis2_test.py diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py new file mode 100644 index 000000000..920eca471 --- /dev/null +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -0,0 +1,127 @@ +import importlib +import sys +import types +import unittest +from unittest.mock import patch + +import torch +from PIL import Image + + +class _DummyPort: + @staticmethod + def Input(*args, **kwargs): + return None + + @staticmethod + def Output(*args, **kwargs): + return None + + +class _DummyIO: + ComfyNode = object + + @staticmethod + def Schema(*args, **kwargs): + return None + + @staticmethod + def NodeOutput(*args, **kwargs): + return args + + def __getattr__(self, name): + return _DummyPort + + +class _DummyTypes: + def __getattr__(self, name): + return lambda *args, **kwargs: None + + +dummy_comfy_api_latest = types.SimpleNamespace( + ComfyExtension=object, + IO=_DummyIO(), + Types=_DummyTypes(), +) + +dummy_sparse_tensor = type("SparseTensor", (), {}) +dummy_trellis_vae = types.SimpleNamespace(SparseTensor=dummy_sparse_tensor) + +with patch.dict(sys.modules, { + "comfy_api.latest": dummy_comfy_api_latest, + "comfy.ldm.trellis2.vae": dummy_trellis_vae, +}): + nodes_trellis2 = importlib.import_module("comfy_extras.nodes_trellis2") + + +class DummyInnerModel: + def __init__(self, image_size=..., fail_on_call=None): + self.call_count = 0 + self.fail_on_call = fail_on_call + if image_size is not ...: + self.image_size = image_size + + def __call__(self, input_tensor, skip_norm_elementwise=True): + self.call_count += 1 + if self.fail_on_call == self.call_count: + raise RuntimeError("expected conditioning failure") + return (torch.ones((1, 4), dtype=torch.float32),) + + +class DummyModel: + def __init__(self, inner_model): + self.model = inner_model + + +class TestRunConditioningRestore(unittest.TestCase): + def setUp(self): + self.intermediate_patch = patch.object( + nodes_trellis2.comfy.model_management, "intermediate_device", lambda: "cpu" + ) + self.torch_device_patch = patch.object( + nodes_trellis2.comfy.model_management, "get_torch_device", lambda: "cpu" + ) + self.intermediate_patch.start() + self.torch_device_patch.start() + + def tearDown(self): + self.intermediate_patch.stop() + self.torch_device_patch.stop() + + @staticmethod + def make_test_image(): + return Image.new("RGB", (8, 8), color="white") + + def test_restores_existing_image_size_after_success(self): + inner_model = DummyInnerModel(image_size=777) + + nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) + + self.assertEqual(inner_model.image_size, 777) + + def test_deletes_missing_image_size_after_success(self): + inner_model = DummyInnerModel() + + nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) + + self.assertFalse(hasattr(inner_model, "image_size")) + + def test_restores_existing_image_size_after_512_failure(self): + inner_model = DummyInnerModel(image_size=777, fail_on_call=1) + + with self.assertRaisesRegex(RuntimeError, "expected conditioning failure"): + nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) + + self.assertEqual(inner_model.image_size, 777) + + def test_deletes_missing_image_size_after_1024_failure(self): + inner_model = DummyInnerModel(fail_on_call=2) + + with self.assertRaisesRegex(RuntimeError, "expected conditioning failure"): + nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) + + self.assertFalse(hasattr(inner_model, "image_size")) + + +if __name__ == "__main__": + unittest.main()