added basic tests for sparse moe
This commit is contained in:
parent
10967ea880
commit
c0e36cd579
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig
|
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig, AIIASparseMoe
|
||||||
|
|
||||||
def test_aiiabase_creation():
|
def test_aiiabase_creation():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
|
@ -106,6 +106,32 @@ def test_aiiamoe_save_load():
|
||||||
os.remove(os.path.join(save_path, "config.json"))
|
os.remove(os.path.join(save_path, "config.json"))
|
||||||
os.rmdir(save_path)
|
os.rmdir(save_path)
|
||||||
|
|
||||||
|
def test_aiiasparsemoe_creation():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIASparseMoe(config, num_experts=5, top_k=2)
|
||||||
|
assert isinstance(model, AIIASparseMoe)
|
||||||
|
|
||||||
|
def test_aiiasparsemoe_save_load():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIASparseMoe(config, num_experts=3, top_k=1)
|
||||||
|
save_path = "test_aiiasparsemoe_save_load"
|
||||||
|
|
||||||
|
# Save the model
|
||||||
|
model.save(save_path)
|
||||||
|
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
||||||
|
assert os.path.exists(os.path.join(save_path, "config.json"))
|
||||||
|
|
||||||
|
# Load the model
|
||||||
|
loaded_model = AIIASparseMoe.load(save_path)
|
||||||
|
|
||||||
|
# Check if the loaded model is an instance of AIIASparseMoe
|
||||||
|
assert isinstance(loaded_model, AIIASparseMoe)
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
os.remove(os.path.join(save_path, "model.pth"))
|
||||||
|
os.remove(os.path.join(save_path, "config.json"))
|
||||||
|
os.rmdir(save_path)
|
||||||
|
|
||||||
def test_aiiachunked_creation():
|
def test_aiiachunked_creation():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
model = AIIAchunked(config)
|
model = AIIAchunked(config)
|
||||||
|
|
Loading…
Reference in New Issue