133 lines
3.9 KiB
Python
133 lines
3.9 KiB
Python
import os
import tempfile

import torch

from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig
|
|
|
|
def test_aiiabase_creation():
    """An AIIABase can be built from a default AIIAConfig."""
    instance = AIIABase(AIIAConfig())
    assert isinstance(instance, AIIABase)
|
|
|
|
def test_aiiabase_save_load():
    """Round-trip an AIIABase through ``save`` and ``AIIABase.load``.

    Artifacts go into a not-yet-existing subdirectory of a temporary
    directory. The test therefore never writes into the current working
    directory, and everything ``save`` produced is removed even when an
    assertion fails.
    """
    config = AIIAConfig()
    model = AIIABase(config)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Hand ``save`` a path that does not exist yet, as the original test
        # did; the test relies on ``save`` creating it.
        save_path = os.path.join(tmp_dir, "test_aiiabase_save_load")

        # Save the model and check both expected artifacts were written.
        model.save(save_path)
        assert os.path.isfile(os.path.join(save_path, "model.pth"))
        assert os.path.isfile(os.path.join(save_path, "config.json"))

        # Load the model back from the saved directory.
        loaded_model = AIIABase.load(save_path)

        # Check if the loaded model is an instance of AIIABase
        assert isinstance(loaded_model, AIIABase)
    # TemporaryDirectory removes save_path and its contents here, even on failure.
|
|
|
|
def test_aiiabase_shared_creation():
    """An AIIABaseShared can be built from a default AIIAConfig."""
    instance = AIIABaseShared(AIIAConfig())
    assert isinstance(instance, AIIABaseShared)
|
|
|
|
def test_aiiabase_shared_save_load():
    """Round-trip an AIIABaseShared through ``save`` and ``AIIABaseShared.load``.

    Artifacts go into a not-yet-existing subdirectory of a temporary
    directory. The test therefore never writes into the current working
    directory, and everything ``save`` produced is removed even when an
    assertion fails.
    """
    config = AIIAConfig()
    model = AIIABaseShared(config)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Hand ``save`` a path that does not exist yet, as the original test
        # did; the test relies on ``save`` creating it.
        save_path = os.path.join(tmp_dir, "test_aiiabase_shared_save_load")

        # Save the model and check both expected artifacts were written.
        model.save(save_path)
        assert os.path.isfile(os.path.join(save_path, "model.pth"))
        assert os.path.isfile(os.path.join(save_path, "config.json"))

        # Load the model back from the saved directory.
        loaded_model = AIIABaseShared.load(save_path)

        # Check if the loaded model is an instance of AIIABaseShared
        assert isinstance(loaded_model, AIIABaseShared)
    # TemporaryDirectory removes save_path and its contents here, even on failure.
|
|
|
|
def test_aiiaexpert_creation():
    """An AIIAExpert can be built from a default AIIAConfig."""
    instance = AIIAExpert(AIIAConfig())
    assert isinstance(instance, AIIAExpert)
|
|
|
|
def test_aiiaexpert_save_load():
    """Round-trip an AIIAExpert through ``save`` and ``AIIAExpert.load``.

    Artifacts go into a not-yet-existing subdirectory of a temporary
    directory. The test therefore never writes into the current working
    directory, and everything ``save`` produced is removed even when an
    assertion fails.
    """
    config = AIIAConfig()
    model = AIIAExpert(config)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Hand ``save`` a path that does not exist yet, as the original test
        # did; the test relies on ``save`` creating it.
        save_path = os.path.join(tmp_dir, "test_aiiaexpert_save_load")

        # Save the model and check both expected artifacts were written.
        model.save(save_path)
        assert os.path.isfile(os.path.join(save_path, "model.pth"))
        assert os.path.isfile(os.path.join(save_path, "config.json"))

        # Load the model back from the saved directory.
        loaded_model = AIIAExpert.load(save_path)

        # Check if the loaded model is an instance of AIIAExpert
        assert isinstance(loaded_model, AIIAExpert)
    # TemporaryDirectory removes save_path and its contents here, even on failure.
|
|
|
|
def test_aiiamoe_creation():
    """An AIIAmoe with five experts can be built from a default AIIAConfig."""
    instance = AIIAmoe(AIIAConfig(), num_experts=5)
    assert isinstance(instance, AIIAmoe)
|
|
|
|
def test_aiiamoe_save_load():
    """Round-trip a five-expert AIIAmoe through ``save`` and ``AIIAmoe.load``.

    Artifacts go into a not-yet-existing subdirectory of a temporary
    directory. The test therefore never writes into the current working
    directory, and everything ``save`` produced is removed even when an
    assertion fails.
    """
    config = AIIAConfig()
    model = AIIAmoe(config, num_experts=5)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Hand ``save`` a path that does not exist yet, as the original test
        # did; the test relies on ``save`` creating it.
        save_path = os.path.join(tmp_dir, "test_aiiamoe_save_load")

        # Save the model and check both expected artifacts were written.
        model.save(save_path)
        assert os.path.isfile(os.path.join(save_path, "model.pth"))
        assert os.path.isfile(os.path.join(save_path, "config.json"))

        # Load the model back from the saved directory.
        # NOTE(review): load() is not passed num_experts; presumably it is
        # restored from the saved artifacts — confirm.
        loaded_model = AIIAmoe.load(save_path)

        # Check if the loaded model is an instance of AIIAmoe
        assert isinstance(loaded_model, AIIAmoe)
    # TemporaryDirectory removes save_path and its contents here, even on failure.
|
|
|
|
def test_aiiachunked_creation():
    """An AIIAchunked can be built from a default AIIAConfig."""
    instance = AIIAchunked(AIIAConfig())
    assert isinstance(instance, AIIAchunked)
|
|
|
|
def test_aiiachunked_save_load():
    """Round-trip an AIIAchunked through ``save`` and ``AIIAchunked.load``.

    Artifacts go into a not-yet-existing subdirectory of a temporary
    directory. The test therefore never writes into the current working
    directory, and everything ``save`` produced is removed even when an
    assertion fails.
    """
    config = AIIAConfig()
    model = AIIAchunked(config)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Hand ``save`` a path that does not exist yet, as the original test
        # did; the test relies on ``save`` creating it.
        save_path = os.path.join(tmp_dir, "test_aiiachunked_save_load")

        # Save the model and check both expected artifacts were written.
        model.save(save_path)
        assert os.path.isfile(os.path.join(save_path, "model.pth"))
        assert os.path.isfile(os.path.join(save_path, "config.json"))

        # Load the model back from the saved directory.
        loaded_model = AIIAchunked.load(save_path)

        # Check if the loaded model is an instance of AIIAchunked
        assert isinstance(loaded_model, AIIAchunked)
    # TemporaryDirectory removes save_path and its contents here, even on failure.