# AIIA/tests/model/test_aiia.py
# (file-viewer metadata at time of copy: 133 lines, 4.8 KiB, Python)

import os
import tempfile

import torch

from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAConfig, AIIASparseMoe
def test_aiiabase_creation():
    """AIIABase can be constructed from a default AIIAConfig."""
    assert isinstance(AIIABase(AIIAConfig()), AIIABase)
def test_aiiabase_save_pretrained_from_pretrained():
    """Round-trip an AIIABase through save_pretrained / from_pretrained.

    The model is saved under a TemporaryDirectory. Its artifacts are removed
    even when an assertion fails, and they never land in the working
    directory, where stale files could make a later run pass spuriously.
    """
    config = AIIAConfig()
    model = AIIABase(config)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Pass a path that does not exist yet; save_pretrained is expected to create it.
        save_pretrained_path = os.path.join(tmp_dir, "test_aiiabase_save_pretrained_load")
        # Save the model and check both expected artifacts were written as files.
        model.save_pretrained(save_pretrained_path)
        assert os.path.isfile(os.path.join(save_pretrained_path, "model.safetensors"))
        assert os.path.isfile(os.path.join(save_pretrained_path, "config.json"))
        # Load the model back and check the class is preserved.
        loaded_model = AIIABase.from_pretrained(save_pretrained_path)
        assert isinstance(loaded_model, AIIABase)
def test_aiiabase_shared_creation():
    """AIIABaseShared can be constructed from a default AIIAConfig."""
    assert isinstance(AIIABaseShared(AIIAConfig()), AIIABaseShared)
def test_aiiabase_shared_save_pretrained_from_pretrained():
    """Round-trip an AIIABaseShared through save_pretrained / from_pretrained.

    The model is saved under a TemporaryDirectory. Its artifacts are removed
    even when an assertion fails, and they never land in the working
    directory, where stale files could make a later run pass spuriously.
    """
    config = AIIAConfig()
    model = AIIABaseShared(config)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Pass a path that does not exist yet; save_pretrained is expected to create it.
        save_pretrained_path = os.path.join(tmp_dir, "test_aiiabase_shared_save_pretrained_load")
        # Save the model and check both expected artifacts were written as files.
        model.save_pretrained(save_pretrained_path)
        assert os.path.isfile(os.path.join(save_pretrained_path, "model.safetensors"))
        assert os.path.isfile(os.path.join(save_pretrained_path, "config.json"))
        # Load the model back and check the class is preserved.
        loaded_model = AIIABaseShared.from_pretrained(save_pretrained_path)
        assert isinstance(loaded_model, AIIABaseShared)
def test_aiiaexpert_creation():
    """AIIAExpert can be constructed from a default AIIAConfig."""
    assert isinstance(AIIAExpert(AIIAConfig()), AIIAExpert)
def test_aiiaexpert_save_pretrained_from_pretrained():
    """Round-trip an AIIAExpert through save_pretrained / from_pretrained.

    The model is saved under a TemporaryDirectory. Its artifacts are removed
    even when an assertion fails, and they never land in the working
    directory, where stale files could make a later run pass spuriously.
    """
    config = AIIAConfig()
    model = AIIAExpert(config)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Pass a path that does not exist yet; save_pretrained is expected to create it.
        save_pretrained_path = os.path.join(tmp_dir, "test_aiiaexpert_save_pretrained_load")
        # Save the model and check both expected artifacts were written as files.
        model.save_pretrained(save_pretrained_path)
        assert os.path.isfile(os.path.join(save_pretrained_path, "model.safetensors"))
        assert os.path.isfile(os.path.join(save_pretrained_path, "config.json"))
        # Load the model back and check the class is preserved.
        loaded_model = AIIAExpert.from_pretrained(save_pretrained_path)
        assert isinstance(loaded_model, AIIAExpert)
def test_aiiamoe_creation():
    """AIIAmoe can be constructed from a config requesting three experts."""
    moe = AIIAmoe(AIIAConfig(num_experts=3))
    assert isinstance(moe, AIIAmoe)
def test_aiiamoe_save_pretrained_from_pretrained():
    """Round-trip a three-expert AIIAmoe through save_pretrained / from_pretrained.

    The model is saved under a TemporaryDirectory. Its artifacts are removed
    even when an assertion fails, and they never land in the working
    directory, where stale files could make a later run pass spuriously.
    """
    config = AIIAConfig(num_experts=3)
    model = AIIAmoe(config)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Pass a path that does not exist yet; save_pretrained is expected to create it.
        save_pretrained_path = os.path.join(tmp_dir, "test_aiiamoe_save_pretrained_load")
        # Save the model and check both expected artifacts were written as files.
        model.save_pretrained(save_pretrained_path)
        assert os.path.isfile(os.path.join(save_pretrained_path, "model.safetensors"))
        assert os.path.isfile(os.path.join(save_pretrained_path, "config.json"))
        # Load the model back and check the class is preserved.
        loaded_model = AIIAmoe.from_pretrained(save_pretrained_path)
        assert isinstance(loaded_model, AIIAmoe)
def test_aiiasparsemoe_creation():
    """AIIASparseMoe accepts an explicit base_class argument (AIIABaseShared)."""
    sparse_moe = AIIASparseMoe(
        AIIAConfig(num_experts=5, top_k=2),
        base_class=AIIABaseShared,
    )
    assert isinstance(sparse_moe, AIIASparseMoe)
def test_aiiasparsemoe_save_pretrained_from_pretrained():
    """Round-trip an AIIASparseMoe (3 experts, top_k=1) through save/load.

    The model is saved under a TemporaryDirectory. Its artifacts are removed
    even when an assertion fails, and they never land in the working
    directory, where stale files could make a later run pass spuriously.
    """
    config = AIIAConfig(num_experts=3, top_k=1)
    model = AIIASparseMoe(config)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Pass a path that does not exist yet; save_pretrained is expected to create it.
        save_pretrained_path = os.path.join(tmp_dir, "test_aiiasparsemoe_save_pretrained_load")
        # Save the model and check both expected artifacts were written as files.
        model.save_pretrained(save_pretrained_path)
        assert os.path.isfile(os.path.join(save_pretrained_path, "model.safetensors"))
        assert os.path.isfile(os.path.join(save_pretrained_path, "config.json"))
        # Load the model back and check the class is preserved.
        loaded_model = AIIASparseMoe.from_pretrained(save_pretrained_path)
        assert isinstance(loaded_model, AIIASparseMoe)