Loop all samplers/schedulers in test_inference.py

This commit is contained in:
enzymezoo-code 2023-09-05 15:45:35 -05:00
parent 10dc6353f0
commit ad6cd712ab

View File

@ -141,8 +141,8 @@ SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
sampler_list = [SAMPLERS[0]] sampler_list = SAMPLERS
scheduler_list = [SCHEDULERS[0]] scheduler_list = SCHEDULERS
@pytest.mark.inference @pytest.mark.inference
@pytest.mark.parametrize("sampler", sampler_list) @pytest.mark.parametrize("sampler", sampler_list)
@pytest.mark.parametrize("scheduler", scheduler_list) @pytest.mark.parametrize("scheduler", scheduler_list)