# AIIA/tests/test_aiia.py
#
# 133 lines
# 3.9 KiB
# Python

import os
import tempfile

import torch

from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig
def test_aiiabase_creation():
    """Construct ``AIIABase`` from a default config and check its type."""
    instance = AIIABase(AIIAConfig())
    assert isinstance(instance, AIIABase)
def test_aiiabase_save_load():
    """Round-trip an ``AIIABase`` model through ``save`` and ``load``.

    Checks that ``save`` leaves ``model.pth`` and ``config.json`` in the
    target directory and that ``AIIABase.load`` returns an ``AIIABase``
    instance when pointed at that directory.

    Everything is written beneath a temporary directory, so the current
    working directory stays untouched and the artifacts are removed even if
    an assertion fails or ``load`` raises part-way through.
    """
    model = AIIABase(AIIAConfig())

    with tempfile.TemporaryDirectory() as tmp_dir:
        # As in the original test, the target directory is not created
        # beforehand; only its (temporary) parent exists.
        save_path = os.path.join(tmp_dir, "test_aiiabase_save_load")

        model.save(save_path)
        assert os.path.exists(os.path.join(save_path, "model.pth"))
        assert os.path.exists(os.path.join(save_path, "config.json"))

        loaded_model = AIIABase.load(save_path)
        assert isinstance(loaded_model, AIIABase)
def test_aiiabase_shared_creation():
    """Construct ``AIIABaseShared`` from a default config and check its type."""
    instance = AIIABaseShared(AIIAConfig())
    assert isinstance(instance, AIIABaseShared)
def test_aiiabase_shared_save_load():
    """Round-trip an ``AIIABaseShared`` model through ``save`` and ``load``.

    Checks that ``save`` leaves ``model.pth`` and ``config.json`` in the
    target directory and that ``AIIABaseShared.load`` returns an
    ``AIIABaseShared`` instance when pointed at that directory.

    Everything is written beneath a temporary directory, so the current
    working directory stays untouched and the artifacts are removed even if
    an assertion fails or ``load`` raises part-way through.
    """
    model = AIIABaseShared(AIIAConfig())

    with tempfile.TemporaryDirectory() as tmp_dir:
        # As in the original test, the target directory is not created
        # beforehand; only its (temporary) parent exists.
        save_path = os.path.join(tmp_dir, "test_aiiabase_shared_save_load")

        model.save(save_path)
        assert os.path.exists(os.path.join(save_path, "model.pth"))
        assert os.path.exists(os.path.join(save_path, "config.json"))

        loaded_model = AIIABaseShared.load(save_path)
        assert isinstance(loaded_model, AIIABaseShared)
def test_aiiaexpert_creation():
    """Construct ``AIIAExpert`` from a default config and check its type."""
    instance = AIIAExpert(AIIAConfig())
    assert isinstance(instance, AIIAExpert)
def test_aiiaexpert_save_load():
    """Round-trip an ``AIIAExpert`` model through ``save`` and ``load``.

    Checks that ``save`` leaves ``model.pth`` and ``config.json`` in the
    target directory and that ``AIIAExpert.load`` returns an ``AIIAExpert``
    instance when pointed at that directory.

    Everything is written beneath a temporary directory, so the current
    working directory stays untouched and the artifacts are removed even if
    an assertion fails or ``load`` raises part-way through.
    """
    model = AIIAExpert(AIIAConfig())

    with tempfile.TemporaryDirectory() as tmp_dir:
        # As in the original test, the target directory is not created
        # beforehand; only its (temporary) parent exists.
        save_path = os.path.join(tmp_dir, "test_aiiaexpert_save_load")

        model.save(save_path)
        assert os.path.exists(os.path.join(save_path, "model.pth"))
        assert os.path.exists(os.path.join(save_path, "config.json"))

        loaded_model = AIIAExpert.load(save_path)
        assert isinstance(loaded_model, AIIAExpert)
def test_aiiamoe_creation():
    """Construct a five-expert ``AIIAmoe`` from a default config and check its type."""
    instance = AIIAmoe(AIIAConfig(), num_experts=5)
    assert isinstance(instance, AIIAmoe)
def test_aiiamoe_save_load():
    """Round-trip a five-expert ``AIIAmoe`` model through ``save`` and ``load``.

    Checks that ``save`` leaves ``model.pth`` and ``config.json`` in the
    target directory and that ``AIIAmoe.load`` returns an ``AIIAmoe``
    instance when pointed at that directory.

    Everything is written beneath a temporary directory, so the current
    working directory stays untouched and the artifacts are removed even if
    an assertion fails or ``load`` raises part-way through.
    """
    model = AIIAmoe(AIIAConfig(), num_experts=5)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # As in the original test, the target directory is not created
        # beforehand; only its (temporary) parent exists.
        save_path = os.path.join(tmp_dir, "test_aiiamoe_save_load")

        model.save(save_path)
        assert os.path.exists(os.path.join(save_path, "model.pth"))
        assert os.path.exists(os.path.join(save_path, "config.json"))

        loaded_model = AIIAmoe.load(save_path)
        assert isinstance(loaded_model, AIIAmoe)
def test_aiiachunked_creation():
    """Construct ``AIIAchunked`` from a default config and check its type."""
    instance = AIIAchunked(AIIAConfig())
    assert isinstance(instance, AIIAchunked)
def test_aiiachunked_save_load():
    """Round-trip an ``AIIAchunked`` model through ``save`` and ``load``.

    Checks that ``save`` leaves ``model.pth`` and ``config.json`` in the
    target directory and that ``AIIAchunked.load`` returns an ``AIIAchunked``
    instance when pointed at that directory.

    Everything is written beneath a temporary directory, so the current
    working directory stays untouched and the artifacts are removed even if
    an assertion fails or ``load`` raises part-way through.
    """
    model = AIIAchunked(AIIAConfig())

    with tempfile.TemporaryDirectory() as tmp_dir:
        # As in the original test, the target directory is not created
        # beforehand; only its (temporary) parent exists.
        save_path = os.path.join(tmp_dir, "test_aiiachunked_save_load")

        model.save(save_path)
        assert os.path.exists(os.path.join(save_path, "model.pth"))
        assert os.path.exists(os.path.join(save_path, "config.json"))

        loaded_model = AIIAchunked.load(save_path)
        assert isinstance(loaded_model, AIIAchunked)