mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 17:42:58 +08:00
Quick fix unit tests.
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.9) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.9) (push) Has been cancelled
This commit is contained in:
parent
8fa3d1281e
commit
83a0fa107b
@ -2,6 +2,7 @@ import unittest
|
|||||||
import torch
|
import torch
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
# Add comfy to path
|
# Add comfy to path
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
@ -15,6 +16,7 @@ if not has_gpu():
|
|||||||
|
|
||||||
from comfy import ops
|
from comfy import ops
|
||||||
from comfy.quant_ops import QuantizedTensor
|
from comfy.quant_ops import QuantizedTensor
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
|
||||||
class SimpleModel(torch.nn.Module):
|
class SimpleModel(torch.nn.Module):
|
||||||
@ -94,8 +96,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
"layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
|
"layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
|
||||||
# Create model and load state dict (strict=False because custom loading pops keys)
|
# Create model and load state dict (strict=False because custom loading pops keys)
|
||||||
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
# Verify weights are wrapped in QuantizedTensor
|
# Verify weights are wrapped in QuantizedTensor
|
||||||
@ -142,7 +145,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||||
}
|
}
|
||||||
|
|
||||||
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
|
||||||
|
model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||||
model.load_state_dict(state_dict1, strict=False)
|
model.load_state_dict(state_dict1, strict=False)
|
||||||
|
|
||||||
# Save state dict
|
# Save state dict
|
||||||
@ -179,7 +183,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||||
}
|
}
|
||||||
|
|
||||||
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
|
||||||
|
model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
# Add a weight function (simulating LoRA)
|
# Add a weight function (simulating LoRA)
|
||||||
@ -216,8 +221,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
|
||||||
|
|
||||||
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
|
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
|
||||||
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||||
with self.assertRaises(KeyError):
|
with self.assertRaises(KeyError):
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user