"""Unit tests for ``comfy_extras.nodes_trellis2``.

The module under test is imported with lightweight stand-ins for
``comfy_api.latest`` and ``comfy.ldm.trellis2.vae`` so that only the node
logic itself is exercised.
"""

import importlib
import sys
import types
import unittest
from unittest.mock import patch

import torch
from PIL import Image


class _DummyPort:
    """Stand-in for an IO port type; both factories accept anything."""

    @staticmethod
    def Input(*args, **kwargs):
        return None

    @staticmethod
    def Output(*args, **kwargs):
        return None


class _DummyIO:
    """Stand-in for the ``IO`` namespace of ``comfy_api.latest``."""

    ComfyNode = object

    @staticmethod
    def Schema(*args, **kwargs):
        return None

    @staticmethod
    def NodeOutput(*args, **kwargs):
        # Hand the positional arguments straight back as a tuple so the tests
        # can index / unpack whatever a node's ``execute`` produced.
        return args

    def __getattr__(self, name):
        # Any attribute not defined above resolves to the dummy port class.
        return _DummyPort


class _DummyTypes:
    """Stand-in for the ``Types`` namespace: every attribute is a no-op callable."""

    def __getattr__(self, name):
        return lambda *args, **kwargs: None


dummy_comfy_api_latest = types.SimpleNamespace(
    ComfyExtension=object,
    IO=_DummyIO(),
    Types=_DummyTypes(),
)
dummy_sparse_tensor = type("SparseTensor", (), {})
dummy_trellis_vae = types.SimpleNamespace(SparseTensor=dummy_sparse_tensor)

# The stand-ins are installed in ``sys.modules`` only while the module under
# test is being imported.
with patch.dict(
    sys.modules,
    {
        "comfy_api.latest": dummy_comfy_api_latest,
        "comfy.ldm.trellis2.vae": dummy_trellis_vae,
    },
):
    nodes_trellis2 = importlib.import_module("comfy_extras.nodes_trellis2")


def _seeded_randn(*shape, seed):
    """Return ``torch.randn(*shape)`` drawn from a fresh CPU generator seeded with *seed*."""
    generator = torch.Generator(device="cpu").manual_seed(seed)
    return torch.randn(*shape, generator=generator)


class DummyInnerModel:
    """Callable fake model that counts its calls and can fail on a chosen one.

    Args:
        image_size: Initial ``image_size`` attribute. When left at the ``...``
            sentinel the attribute is not created at all, which lets tests
            distinguish "attribute absent" from any real value.
        fail_on_call: 1-based call number on which ``__call__`` raises
            ``RuntimeError``; ``None`` means it never raises.
    """

    def __init__(self, image_size=..., fail_on_call=None):
        self.call_count = 0
        self.fail_on_call = fail_on_call
        if image_size is not ...:
            self.image_size = image_size

    def __call__(self, input_tensor, skip_norm_elementwise=True):
        self.call_count += 1
        if self.fail_on_call == self.call_count:
            raise RuntimeError("expected conditioning failure")
        return (torch.ones((1, 4), dtype=torch.float32),)


class DummyModel:
    """Wrapper exposing the inner model as ``.model``."""

    def __init__(self, inner_model):
        self.model = inner_model


class TestRunConditioningRestore(unittest.TestCase):
    """``run_conditioning`` must leave the inner model's ``image_size`` as it found it."""

    def setUp(self):
        model_management = nodes_trellis2.comfy.model_management
        for attribute in ("intermediate_device", "get_torch_device"):
            patcher = patch.object(model_management, attribute, lambda: "cpu")
            patcher.start()
            # Register the undo step right away: unittest skips tearDown when
            # setUp raises, but cleanups registered so far still run, so an
            # already-started patch can never leak into other tests.
            self.addCleanup(patcher.stop)

    @staticmethod
    def make_test_image():
        """Return a small white RGB image to feed into conditioning."""
        return Image.new("RGB", (8, 8), color="white")

    def _run_conditioning(self, inner_model):
        """Invoke ``run_conditioning`` on *inner_model* with ``include_1024=True``."""
        return nodes_trellis2.run_conditioning(
            DummyModel(inner_model),
            self.make_test_image(),
            include_1024=True,
        )

    def test_restores_existing_image_size_after_success(self):
        inner_model = DummyInnerModel(image_size=777)

        self._run_conditioning(inner_model)

        self.assertEqual(inner_model.image_size, 777)

    def test_deletes_missing_image_size_after_success(self):
        inner_model = DummyInnerModel()

        self._run_conditioning(inner_model)

        self.assertFalse(hasattr(inner_model, "image_size"))

    def test_restores_existing_image_size_after_512_failure(self):
        # The first model call raises.
        inner_model = DummyInnerModel(image_size=777, fail_on_call=1)

        with self.assertRaisesRegex(RuntimeError, "expected conditioning failure"):
            self._run_conditioning(inner_model)

        self.assertEqual(inner_model.image_size, 777)

    def test_deletes_missing_image_size_after_1024_failure(self):
        # The second model call raises.
        inner_model = DummyInnerModel(fail_on_call=2)

        with self.assertRaisesRegex(RuntimeError, "expected conditioning failure"):
            self._run_conditioning(inner_model)

        self.assertFalse(hasattr(inner_model, "image_size"))


class DummyCloneModel:
    """Minimal model exposing ``model_options`` and a ``clone`` that shallow-copies it."""

    def __init__(self):
        self.model_options = {}

    def clone(self):
        cloned = DummyCloneModel()
        cloned.model_options = self.model_options.copy()
        return cloned


class TestTrellisBatchSemantics(unittest.TestCase):
    """Batch, seed and ``batch_index`` behaviour of the Trellis2 latent nodes."""

    def test_empty_structure_latent_is_deterministic_and_propagates_sample_indices(self):
        batch_output = nodes_trellis2.EmptyStructureLatentTrellis2.execute(2, 0, 17)[0]
        single_output = nodes_trellis2.EmptyStructureLatentTrellis2.execute(1, 5, 17)[0]

        # The expectations below use one generator seed per sample: 17 and 18
        # for the two-sample call, 22 for the single-sample call.
        expected_batch = torch.zeros(2, 8, 16, 16, 16)
        expected_batch[0] = _seeded_randn(1, 8, 16, 16, 16, seed=17)[0]
        expected_batch[1] = _seeded_randn(1, 8, 16, 16, 16, seed=18)[0]
        expected_single = _seeded_randn(1, 8, 16, 16, 16, seed=22)

        self.assertTrue(torch.equal(batch_output["samples"], expected_batch))
        self.assertEqual(batch_output["batch_index"], [0, 1])
        self.assertTrue(torch.equal(single_output["samples"], expected_single))
        self.assertEqual(single_output["batch_index"], [5])

    def test_empty_shape_latent_is_deterministic_and_propagates_batch_index(self):
        # Rows are deliberately interleaved between first-column values 0 and 1.
        coords = torch.tensor(
            [
                [1, 5, 5, 5],
                [0, 1, 1, 1],
                [1, 6, 6, 6],
                [0, 2, 2, 2],
                [1, 7, 7, 7],
            ],
            dtype=torch.int32,
        )
        structure = {
            "coords": coords,
            "coord_counts": torch.tensor([2, 3], dtype=torch.int64),
            "batch_index": [4, 9],
        }

        output, _ = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 23)

        # Expected per-sample seeds are 27 and 32; entries beyond each
        # sample's coord count stay zero.
        expected = torch.zeros(2, 32, 3, 1)
        expected[0, :, :2, :] = _seeded_randn(1, 32, 2, 1, seed=27)[0]
        expected[1, :, :3, :] = _seeded_randn(1, 32, 3, 1, seed=32)[0]

        self.assertTrue(torch.equal(output["samples"], expected))
        self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2, 3], dtype=torch.int64)))
        self.assertEqual(output["batch_index"], [4, 9])

    def test_empty_shape_latent_keeps_singleton_coord_counts(self):
        # No "coord_counts" key is supplied in the input structure.
        structure = {
            "coords": torch.tensor(
                [
                    [0, 1, 1, 1],
                    [0, 2, 2, 2],
                ],
                dtype=torch.int32,
            ),
        }

        output, _ = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11)

        self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2], dtype=torch.int64)))

    def test_flatten_batched_sparse_latent_validates_coord_counts(self):
        samples = torch.zeros(2, 32, 3, 1)
        coords = torch.tensor(
            [
                [0, 1, 1, 1],
                [1, 2, 2, 2],
                [1, 3, 3, 3],
            ],
            dtype=torch.int32,
        )
        # coord_counts says [2, 1], but the first column of coords holds one
        # 0 and two 1s.
        coord_counts = torch.tensor([2, 1], dtype=torch.int64)

        with self.assertRaises(ValueError):
            nodes_trellis2.flatten_batched_sparse_latent(samples, coords, coord_counts)


if __name__ == "__main__":
    unittest.main()