From c0e36cd5799b69484f401a22ed94dc515efd9880 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Wed, 26 Mar 2025 21:26:14 +0100 Subject: [PATCH] added basic tests for sparse moe --- tests/model/test_aiia.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/model/test_aiia.py b/tests/model/test_aiia.py index b8904a2..340537b 100644 --- a/tests/model/test_aiia.py +++ b/tests/model/test_aiia.py @@ -1,6 +1,6 @@ import os import torch -from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig +from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig, AIIASparseMoe def test_aiiabase_creation(): config = AIIAConfig() @@ -106,6 +106,32 @@ def test_aiiamoe_save_load(): os.remove(os.path.join(save_path, "config.json")) os.rmdir(save_path) +def test_aiiasparsemoe_creation(): + config = AIIAConfig() + model = AIIASparseMoe(config, num_experts=5, top_k=2) + assert isinstance(model, AIIASparseMoe) + +def test_aiiasparsemoe_save_load(): + config = AIIAConfig() + model = AIIASparseMoe(config, num_experts=3, top_k=1) + save_path = "test_aiiasparsemoe_save_load" + + # Save the model + model.save(save_path) + assert os.path.exists(os.path.join(save_path, "model.pth")) + assert os.path.exists(os.path.join(save_path, "config.json")) + + # Load the model + loaded_model = AIIASparseMoe.load(save_path) + + # Check if the loaded model is an instance of AIIASparseMoe + assert isinstance(loaded_model, AIIASparseMoe) + + # Clean up + os.remove(os.path.join(save_path, "model.pth")) + os.remove(os.path.join(save_path, "config.json")) + os.rmdir(save_path) + def test_aiiachunked_creation(): config = AIIAConfig() model = AIIAchunked(config)