diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 63361309f..5f45bfeef 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -115,7 +115,8 @@ class TestMixedPrecisionOps(unittest.TestCase): # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output = model(input_tensor) + with torch.inference_mode(): + output = model(input_tensor) self.assertEqual(output.shape, (5, 40))