From 9ec01e2dd0f9a10d0f6c2f87992abf308a6f2bf0 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sun, 13 Apr 2025 22:19:15 +0200 Subject: [PATCH] tests with new config --- tests/model/test_aiia.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/model/test_aiia.py b/tests/model/test_aiia.py index 6177298..4891472 100644 --- a/tests/model/test_aiia.py +++ b/tests/model/test_aiia.py @@ -12,7 +12,7 @@ def test_aiiabase_save_pretrained_from_pretrained(): model = AIIABase(config) save_pretrained_path = "test_aiiabase_save_pretrained_load" - # save_pretrained the model + # Save the model model.save_pretrained(save_pretrained_path) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) @@ -38,7 +38,7 @@ def test_aiiabase_shared_save_pretrained_from_pretrained(): model = AIIABaseShared(config) save_pretrained_path = "test_aiiabase_shared_save_pretrained_load" - # save_pretrained the model + # Save the model model.save_pretrained(save_pretrained_path) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) @@ -64,7 +64,7 @@ def test_aiiaexpert_save_pretrained_from_pretrained(): model = AIIAExpert(config) save_pretrained_path = "test_aiiaexpert_save_pretrained_load" - # save_pretrained the model + # Save the model model.save_pretrained(save_pretrained_path) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) @@ -81,16 +81,16 @@ def test_aiiaexpert_save_pretrained_from_pretrained(): os.rmdir(save_pretrained_path) def test_aiiamoe_creation(): - config = AIIAConfig() - model = AIIAmoe(config, num_experts=5) + config = AIIAConfig(num_experts=3) + model = AIIAmoe(config) assert isinstance(model, AIIAmoe) def test_aiiamoe_save_pretrained_from_pretrained(): - config = AIIAConfig() - model = AIIAmoe(config, num_experts=5) + config = AIIAConfig(num_experts=3) + model = AIIAmoe(config) save_pretrained_path = "test_aiiamoe_save_pretrained_load" - # save_pretrained the model + # Save the model model.save_pretrained(save_pretrained_path) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) @@ -107,16 +107,16 @@ def test_aiiamoe_save_pretrained_from_pretrained(): os.rmdir(save_pretrained_path) def test_aiiasparsemoe_creation(): - config = AIIAConfig() - model = AIIASparseMoe(config, num_experts=5, top_k=2) + config = AIIAConfig(num_experts=5, top_k=2) + model = AIIASparseMoe(config, base_class=AIIABaseShared) assert isinstance(model, AIIASparseMoe) def test_aiiasparsemoe_save_pretrained_from_pretrained(): - config = AIIAConfig() - model = AIIASparseMoe(config, num_experts=3, top_k=1) + config = AIIAConfig(num_experts=3, top_k=1) + model = AIIASparseMoe(config) save_pretrained_path = "test_aiiasparsemoe_save_pretrained_load" - - # save_pretrained the model + + # Save the model model.save_pretrained(save_pretrained_path) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) @@ -130,4 +130,4 @@ def test_aiiasparsemoe_save_pretrained_from_pretrained(): # Clean up os.remove(os.path.join(save_pretrained_path, "model.safetensors")) os.remove(os.path.join(save_pretrained_path, "config.json")) - os.rmdir(save_pretrained_path) + os.rmdir(save_pretrained_path) \ No newline at end of file