Quick fix unit tests.

comfyanonymous 2025-12-03 20:11:28 -05:00
parent 8fa3d1281e
commit 83a0fa107b


@@ -2,6 +2,7 @@ import unittest
 import torch
 import sys
 import os
+import json

 # Add comfy to path
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -15,6 +16,7 @@ if not has_gpu():
 from comfy import ops
 from comfy.quant_ops import QuantizedTensor
+import comfy.utils

 class SimpleModel(torch.nn.Module):
@@ -94,8 +96,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
         }
+        state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})

         # Create model and load state dict (strict=False because custom loading pops keys)
-        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
         model.load_state_dict(state_dict, strict=False)

         # Verify weights are wrapped in QuantizedTensor
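Aside: the pattern introduced in this hunk moves the per-layer quantization config out of the ops factory and into an up-front state-dict conversion. A minimal sketch of the new flow, assuming only the call signatures visible in this diff; the config shape and format name below are illustrative, not confirmed by the source:

    import json
    import torch
    import comfy.utils
    from comfy import ops

    # Old-style checkpoint: plain weight plus a separate per-layer scale key,
    # mirroring the keys used in the test above.
    state_dict = {
        "layer1.weight": torch.randn(20, 10),
        "layer1.weight_scale": torch.tensor(1.5, dtype=torch.float32),
    }
    layer_quant_config = {"layer1": {"format": "float8_e4m3fn"}}  # hypothetical shape

    # convert_old_quants rewrites the old-style keys using the JSON metadata blob.
    metadata = {"_quantization_metadata": json.dumps({"layers": layer_quant_config})}
    state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata=metadata)

    # The ops factory now takes an empty config; the quantization layout is
    # carried by the converted state dict instead.
    model = SimpleModel(operations=ops.mixed_precision_ops({}))
    model.load_state_dict(state_dict, strict=False)  # custom loading pops keys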
@@ -142,7 +145,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

-        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+        state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
         model.load_state_dict(state_dict1, strict=False)

         # Save state dict
@@ -179,7 +183,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

-        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+        state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
         model.load_state_dict(state_dict, strict=False)

         # Add a weight function (simulating LoRA)
@@ -216,8 +221,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }
+        state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})

         # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
-        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
         with self.assertRaises(KeyError):
             model.load_state_dict(state_dict, strict=False)
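For the error path in this last hunk, the expectation is that metadata naming a format absent from QUANT_FORMAT_MIXINS surfaces as a KeyError during load_state_dict, not at ops-construction time. A sketch under the same assumptions as above (the bogus format name is made up):

    bad_config = {"layer1": {"format": "no_such_format"}}  # hypothetical unknown format
    metadata = {"_quantization_metadata": json.dumps({"layers": bad_config})}
    state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata=metadata)

    # Construction succeeds with an empty config; the failure is deferred to load,
    # where the unknown format lookup in QUANT_FORMAT_MIXINS raises KeyError.
    model = SimpleModel(operations=ops.mixed_precision_ops({}))
    with self.assertRaises(KeyError):
        model.load_state_dict(state_dict, strict=False)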