Fix tests.
This commit is contained in: parent cb2b4f282e, commit 0d6791be86
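
Summary of the change: these tests previously configured quantization by assigning to the MixedPrecisionOps class attribute _layer_quant_config and then passing the class itself as operations=. The ops module now exposes a mixed_precision_ops(...) factory that takes the per-layer config directly, so each hunk below drops the class-attribute assignment and builds the model from the factory's return value. A minimal sketch of the new pattern; the import path, the layer shape, and the format string are all assumed for illustration (SimpleModel in the real test plays the role of TinyModel here):

    import torch
    import torch.nn as nn
    import comfy.ops as ops  # assumed import path for the module under test

    class TinyModel(nn.Module):
        """Stand-in for the test file's SimpleModel: one layer built from
        whatever operations class is passed in."""
        def __init__(self, operations):
            super().__init__()
            self.layer1 = operations.Linear(10, 20)

    # Per-layer quantization config shaped like the ones in these tests:
    # layer name -> {"format": ..., "params": {...}}; the concrete format
    # string below is a placeholder, not taken from the diff.
    layer_quant_config = {
        "layer1": {"format": "float8_e4m3fn", "params": {}},
    }

    # New pattern: the factory returns an operations class bound to the config.
    model = TinyModel(operations=ops.mixed_precision_ops(layer_quant_config))

    # Old pattern removed by this commit: mutate shared class state, then pass the class.
    # ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
    # model = TinyModel(operations=ops.MixedPrecisionOps)
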
@@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
 
     def test_all_layers_standard(self):
         """Test that model with no quantization works normally"""
-        # Configure no quantization
-        ops.MixedPrecisionOps._layer_quant_config = {}
-
         # Create model
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
 
         # Initialize weights manually
         model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
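
The first hunk covers the unquantized case: the factory is simply called with an empty dict, so no layer gets special handling. The factory's implementation is not part of this diff; purely as a sketch of why the tests can stop touching shared class state, it could be as small as a closure that binds the config to a fresh subclass:

    import comfy.ops as ops  # assumed import path

    def mixed_precision_ops_sketch(layer_quant_config):
        """Sketch only, not the real implementation: return a new operations
        class carrying its own per-layer quantization config, so each test
        owns its config instead of mutating MixedPrecisionOps globally."""
        class _Ops(ops.MixedPrecisionOps):
            _layer_quant_config = layer_quant_config
        return _Ops

    # Empty config: every layer takes the standard (non-quantized) path,
    # which is what test_all_layers_standard exercises.
    standard_ops = mixed_precision_ops_sketch({})
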
@@ -76,7 +73,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
 
         # Create state dict with mixed precision
         fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -99,7 +95,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
         }
 
         # Create model and load state dict (strict=False because custom loading pops keys)
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         model.load_state_dict(state_dict, strict=False)
 
         # Verify weights are wrapped in QuantizedTensor
@@ -132,7 +128,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
 
         # Create and load model
         fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -146,7 +141,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }
 
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
        model.load_state_dict(state_dict1, strict=False)
 
         # Save state dict
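
This hunk sits inside the save/load round-trip test: a model is built via the factory, loaded from state_dict1, and its state dict is then saved (the reload and comparison fall outside the visible context). The pattern, shown here with a plain nn.Linear so the snippet stands alone, is the usual state-dict round trip; the real test does the same thing with fp8 weights:

    import torch
    import torch.nn as nn

    # Round-trip sketch with an ordinary linear layer: save the state dict,
    # load it into a fresh module, and check the tensors survived unchanged.
    src = nn.Linear(10, 20)
    saved = {k: v.clone() for k, v in src.state_dict().items()}

    dst = nn.Linear(10, 20)
    dst.load_state_dict(saved)

    for key, value in saved.items():
        assert torch.equal(value, dst.state_dict()[key])
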
@@ -170,7 +165,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
 
         # Create and load model
         fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -184,7 +178,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }
 
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         model.load_state_dict(state_dict, strict=False)
 
         # Add a weight function (simulating LoRA)
@@ -210,7 +204,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
 
         # Create state dict
         state_dict = {
@@ -223,7 +216,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
         }
 
         # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         with self.assertRaises(KeyError):
             model.load_state_dict(state_dict, strict=False)
 
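
The last hunk is the negative test: the config names a quantization format with no entry in QUANT_FORMAT_MIXINS, and load_state_dict is expected to raise KeyError. Only that contract is visible in the diff; the stand-in registry below is a placeholder used to make the sketch self-contained:

    import unittest

    # Stand-in for the real registry, which maps known quantization format
    # names to their mixin implementations.
    FAKE_QUANT_FORMAT_MIXINS = {"float8_e4m3fn": object}

    def resolve_format(fmt):
        # Mirrors the failure mode the test asserts: an unknown format name
        # raises KeyError when it is looked up during loading.
        return FAKE_QUANT_FORMAT_MIXINS[fmt]

    class UnknownFormatContract(unittest.TestCase):
        def test_unknown_format_raises_keyerror(self):
            with self.assertRaises(KeyError):
                resolve_format("some_unknown_format")

    if __name__ == "__main__":
        unittest.main()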