added basic tests for sparse moe

This commit is contained in:
Falko Victor Habel 2025-03-26 21:26:14 +01:00
parent 10967ea880
commit c0e36cd579
1 changed files with 27 additions and 1 deletions

View File

@ -1,6 +1,6 @@
import os
import torch
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig, AIIASparseMoe
def test_aiiabase_creation():
config = AIIAConfig()
@ -106,6 +106,32 @@ def test_aiiamoe_save_load():
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
def test_aiiasparsemoe_creation():
config = AIIAConfig()
model = AIIASparseMoe(config, num_experts=5, top_k=2)
assert isinstance(model, AIIASparseMoe)
def test_aiiasparsemoe_save_load():
config = AIIAConfig()
model = AIIASparseMoe(config, num_experts=3, top_k=1)
save_path = "test_aiiasparsemoe_save_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# Load the model
loaded_model = AIIASparseMoe.load(save_path)
# Check if the loaded model is an instance of AIIASparseMoe
assert isinstance(loaded_model, AIIASparseMoe)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
def test_aiiachunked_creation():
config = AIIAConfig()
model = AIIAchunked(config)