mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-29 11:52:45 +08:00
Merge pull request #8 from pollockjj/issue_89
Restore Trellis2 clip vision image_size state
This commit is contained in:
commit
dd8c5703d8
@ -256,6 +256,8 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True):
|
|||||||
model_internal = model.model
|
model_internal = model.model
|
||||||
device = comfy.model_management.intermediate_device()
|
device = comfy.model_management.intermediate_device()
|
||||||
torch_device = comfy.model_management.get_torch_device()
|
torch_device = comfy.model_management.get_torch_device()
|
||||||
|
had_image_size = hasattr(model_internal, "image_size")
|
||||||
|
original_image_size = getattr(model_internal, "image_size", None)
|
||||||
|
|
||||||
def prepare_tensor(pil_img, size):
|
def prepare_tensor(pil_img, size):
|
||||||
resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS)
|
resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS)
|
||||||
@ -263,15 +265,21 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True):
|
|||||||
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
|
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
|
||||||
return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||||
|
|
||||||
model_internal.image_size = 512
|
|
||||||
input_512 = prepare_tensor(cropped_img_tensor, 512)
|
|
||||||
cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0]
|
|
||||||
|
|
||||||
cond_1024 = None
|
cond_1024 = None
|
||||||
if include_1024:
|
try:
|
||||||
model_internal.image_size = 1024
|
model_internal.image_size = 512
|
||||||
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
|
input_512 = prepare_tensor(cropped_img_tensor, 512)
|
||||||
cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0]
|
cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0]
|
||||||
|
|
||||||
|
if include_1024:
|
||||||
|
model_internal.image_size = 1024
|
||||||
|
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
|
||||||
|
cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0]
|
||||||
|
finally:
|
||||||
|
if not had_image_size:
|
||||||
|
delattr(model_internal, "image_size")
|
||||||
|
else:
|
||||||
|
model_internal.image_size = original_image_size
|
||||||
|
|
||||||
conditioning = {
|
conditioning = {
|
||||||
'cond_512': cond_512.to(device),
|
'cond_512': cond_512.to(device),
|
||||||
|
|||||||
127
tests-unit/comfy_extras_test/nodes_trellis2_test.py
Normal file
127
tests-unit/comfy_extras_test/nodes_trellis2_test.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyPort:
|
||||||
|
@staticmethod
|
||||||
|
def Input(*args, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Output(*args, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyIO:
|
||||||
|
ComfyNode = object
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Schema(*args, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def NodeOutput(*args, **kwargs):
|
||||||
|
return args
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return _DummyPort
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyTypes:
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return lambda *args, **kwargs: None
|
||||||
|
|
||||||
|
|
||||||
|
dummy_comfy_api_latest = types.SimpleNamespace(
|
||||||
|
ComfyExtension=object,
|
||||||
|
IO=_DummyIO(),
|
||||||
|
Types=_DummyTypes(),
|
||||||
|
)
|
||||||
|
|
||||||
|
dummy_sparse_tensor = type("SparseTensor", (), {})
|
||||||
|
dummy_trellis_vae = types.SimpleNamespace(SparseTensor=dummy_sparse_tensor)
|
||||||
|
|
||||||
|
with patch.dict(sys.modules, {
|
||||||
|
"comfy_api.latest": dummy_comfy_api_latest,
|
||||||
|
"comfy.ldm.trellis2.vae": dummy_trellis_vae,
|
||||||
|
}):
|
||||||
|
nodes_trellis2 = importlib.import_module("comfy_extras.nodes_trellis2")
|
||||||
|
|
||||||
|
|
||||||
|
class DummyInnerModel:
|
||||||
|
def __init__(self, image_size=..., fail_on_call=None):
|
||||||
|
self.call_count = 0
|
||||||
|
self.fail_on_call = fail_on_call
|
||||||
|
if image_size is not ...:
|
||||||
|
self.image_size = image_size
|
||||||
|
|
||||||
|
def __call__(self, input_tensor, skip_norm_elementwise=True):
|
||||||
|
self.call_count += 1
|
||||||
|
if self.fail_on_call == self.call_count:
|
||||||
|
raise RuntimeError("expected conditioning failure")
|
||||||
|
return (torch.ones((1, 4), dtype=torch.float32),)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyModel:
|
||||||
|
def __init__(self, inner_model):
|
||||||
|
self.model = inner_model
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunConditioningRestore(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.intermediate_patch = patch.object(
|
||||||
|
nodes_trellis2.comfy.model_management, "intermediate_device", lambda: "cpu"
|
||||||
|
)
|
||||||
|
self.torch_device_patch = patch.object(
|
||||||
|
nodes_trellis2.comfy.model_management, "get_torch_device", lambda: "cpu"
|
||||||
|
)
|
||||||
|
self.intermediate_patch.start()
|
||||||
|
self.torch_device_patch.start()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.intermediate_patch.stop()
|
||||||
|
self.torch_device_patch.stop()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_test_image():
|
||||||
|
return Image.new("RGB", (8, 8), color="white")
|
||||||
|
|
||||||
|
def test_restores_existing_image_size_after_success(self):
|
||||||
|
inner_model = DummyInnerModel(image_size=777)
|
||||||
|
|
||||||
|
nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True)
|
||||||
|
|
||||||
|
self.assertEqual(inner_model.image_size, 777)
|
||||||
|
|
||||||
|
def test_deletes_missing_image_size_after_success(self):
|
||||||
|
inner_model = DummyInnerModel()
|
||||||
|
|
||||||
|
nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True)
|
||||||
|
|
||||||
|
self.assertFalse(hasattr(inner_model, "image_size"))
|
||||||
|
|
||||||
|
def test_restores_existing_image_size_after_512_failure(self):
|
||||||
|
inner_model = DummyInnerModel(image_size=777, fail_on_call=1)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "expected conditioning failure"):
|
||||||
|
nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True)
|
||||||
|
|
||||||
|
self.assertEqual(inner_model.image_size, 777)
|
||||||
|
|
||||||
|
def test_deletes_missing_image_size_after_1024_failure(self):
|
||||||
|
inner_model = DummyInnerModel(fail_on_call=2)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "expected conditioning failure"):
|
||||||
|
nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True)
|
||||||
|
|
||||||
|
self.assertFalse(hasattr(inner_model, "image_size"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in New Issue
Block a user