added tests #29

Merged
Fabel merged 1 commits from added_tests into develop 2025-03-16 11:26:13 +00:00
2 changed files with 208 additions and 0 deletions

133
tests/test_aiia.py Normal file
View File

@ -0,0 +1,133 @@
import os
import torch
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig
def test_aiiabase_creation():
config = AIIAConfig()
model = AIIABase(config)
assert isinstance(model, AIIABase)
def test_aiiabase_save_load():
config = AIIAConfig()
model = AIIABase(config)
save_path = "test_aiiabase_save_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# Load the model
loaded_model = AIIABase.load(save_path)
# Check if the loaded model is an instance of AIIABase
assert isinstance(loaded_model, AIIABase)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
def test_aiiabase_shared_creation():
config = AIIAConfig()
model = AIIABaseShared(config)
assert isinstance(model, AIIABaseShared)
def test_aiiabase_shared_save_load():
config = AIIAConfig()
model = AIIABaseShared(config)
save_path = "test_aiiabase_shared_save_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# Load the model
loaded_model = AIIABaseShared.load(save_path)
# Check if the loaded model is an instance of AIIABaseShared
assert isinstance(loaded_model, AIIABaseShared)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
def test_aiiaexpert_creation():
config = AIIAConfig()
model = AIIAExpert(config)
assert isinstance(model, AIIAExpert)
def test_aiiaexpert_save_load():
config = AIIAConfig()
model = AIIAExpert(config)
save_path = "test_aiiaexpert_save_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# Load the model
loaded_model = AIIAExpert.load(save_path)
# Check if the loaded model is an instance of AIIAExpert
assert isinstance(loaded_model, AIIAExpert)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
def test_aiiamoe_creation():
config = AIIAConfig()
model = AIIAmoe(config, num_experts=5)
assert isinstance(model, AIIAmoe)
def test_aiiamoe_save_load():
config = AIIAConfig()
model = AIIAmoe(config, num_experts=5)
save_path = "test_aiiamoe_save_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# Load the model
loaded_model = AIIAmoe.load(save_path)
# Check if the loaded model is an instance of AIIAmoe
assert isinstance(loaded_model, AIIAmoe)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
def test_aiiachunked_creation():
config = AIIAConfig()
model = AIIAchunked(config)
assert isinstance(model, AIIAchunked)
def test_aiiachunked_save_load():
config = AIIAConfig()
model = AIIAchunked(config)
save_path = "test_aiiachunked_save_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# Load the model
loaded_model = AIIAchunked.load(save_path)
# Check if the loaded model is an instance of AIIAchunked
assert isinstance(loaded_model, AIIAchunked)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)

75
tests/test_config.py Normal file
View File

@ -0,0 +1,75 @@
import os
import tempfile
import pytest
import torch.nn as nn
from aiia import AIIAConfig
def test_aiia_config_initialization():
config = AIIAConfig()
assert config.model_name == "AIIA"
assert config.kernel_size == 3
assert config.activation_function == "GELU"
assert config.hidden_size == 512
assert config.num_hidden_layers == 12
assert config.num_channels == 3
assert config.learning_rate == 5e-5
def test_aiia_config_custom_initialization():
config = AIIAConfig(
model_name="CustomModel",
kernel_size=5,
activation_function="ReLU",
hidden_size=1024,
num_hidden_layers=8,
num_channels=1,
learning_rate=1e-4
)
assert config.model_name == "CustomModel"
assert config.kernel_size == 5
assert config.activation_function == "ReLU"
assert config.hidden_size == 1024
assert config.num_hidden_layers == 8
assert config.num_channels == 1
assert config.learning_rate == 1e-4
def test_aiia_config_invalid_activation_function():
with pytest.raises(ValueError):
AIIAConfig(activation_function="InvalidFunction")
def test_aiia_config_to_dict():
config = AIIAConfig()
config_dict = config.to_dict()
assert isinstance(config_dict, dict)
assert config_dict["model_name"] == "AIIA"
assert config_dict["kernel_size"] == 3
def test_aiia_config_save_and_load():
with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_name="TempModel")
save_path = os.path.join(tmpdir, "config")
config.save(save_path)
loaded_config = AIIAConfig.load(save_path)
assert loaded_config.model_name == "TempModel"
assert loaded_config.kernel_size == 3
assert loaded_config.activation_function == "GELU"
def test_aiia_config_save_and_load_with_custom_attributes():
with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_name="TempModel", custom_attr="value")
save_path = os.path.join(tmpdir, "config")
config.save(save_path)
loaded_config = AIIAConfig.load(save_path)
assert loaded_config.model_name == "TempModel"
assert loaded_config.custom_attr == "value"
def test_aiia_config_save_and_load_with_nested_attributes():
with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_name="TempModel", nested={"key": "value"})
save_path = os.path.join(tmpdir, "config")
config.save(save_path)
loaded_config = AIIAConfig.load(save_path)
assert loaded_config.model_name == "TempModel"
assert loaded_config.nested == {"key": "value"}