From da93db23c0e6b471849bf2e93da260fb665351b0 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sat, 15 Mar 2025 18:48:27 +0100 Subject: [PATCH] added tests --- tests/test_aiia.py | 133 +++++++++++++++++++++++++++++++++++++++++++ tests/test_config.py | 75 ++++++++++++++++++++++++ 2 files changed, 208 insertions(+) create mode 100644 tests/test_aiia.py create mode 100644 tests/test_config.py diff --git a/tests/test_aiia.py b/tests/test_aiia.py new file mode 100644 index 0000000..b8904a2 --- /dev/null +++ b/tests/test_aiia.py @@ -0,0 +1,133 @@ +import os +import torch +from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig + +def test_aiiabase_creation(): + config = AIIAConfig() + model = AIIABase(config) + assert isinstance(model, AIIABase) + +def test_aiiabase_save_load(): + config = AIIAConfig() + model = AIIABase(config) + save_path = "test_aiiabase_save_load" + + # Save the model + model.save(save_path) + assert os.path.exists(os.path.join(save_path, "model.pth")) + assert os.path.exists(os.path.join(save_path, "config.json")) + + # Load the model + loaded_model = AIIABase.load(save_path) + + # Check if the loaded model is an instance of AIIABase + assert isinstance(loaded_model, AIIABase) + + # Clean up + os.remove(os.path.join(save_path, "model.pth")) + os.remove(os.path.join(save_path, "config.json")) + os.rmdir(save_path) + +def test_aiiabase_shared_creation(): + config = AIIAConfig() + model = AIIABaseShared(config) + assert isinstance(model, AIIABaseShared) + +def test_aiiabase_shared_save_load(): + config = AIIAConfig() + model = AIIABaseShared(config) + save_path = "test_aiiabase_shared_save_load" + + # Save the model + model.save(save_path) + assert os.path.exists(os.path.join(save_path, "model.pth")) + assert os.path.exists(os.path.join(save_path, "config.json")) + + # Load the model + loaded_model = AIIABaseShared.load(save_path) + + # Check if the loaded model is an instance of AIIABaseShared + assert isinstance(loaded_model, AIIABaseShared) + + # Clean up + os.remove(os.path.join(save_path, "model.pth")) + os.remove(os.path.join(save_path, "config.json")) + os.rmdir(save_path) + +def test_aiiaexpert_creation(): + config = AIIAConfig() + model = AIIAExpert(config) + assert isinstance(model, AIIAExpert) + +def test_aiiaexpert_save_load(): + config = AIIAConfig() + model = AIIAExpert(config) + save_path = "test_aiiaexpert_save_load" + + # Save the model + model.save(save_path) + assert os.path.exists(os.path.join(save_path, "model.pth")) + assert os.path.exists(os.path.join(save_path, "config.json")) + + # Load the model + loaded_model = AIIAExpert.load(save_path) + + # Check if the loaded model is an instance of AIIAExpert + assert isinstance(loaded_model, AIIAExpert) + + # Clean up + os.remove(os.path.join(save_path, "model.pth")) + os.remove(os.path.join(save_path, "config.json")) + os.rmdir(save_path) + +def test_aiiamoe_creation(): + config = AIIAConfig() + model = AIIAmoe(config, num_experts=5) + assert isinstance(model, AIIAmoe) + +def test_aiiamoe_save_load(): + config = AIIAConfig() + model = AIIAmoe(config, num_experts=5) + save_path = "test_aiiamoe_save_load" + + # Save the model + model.save(save_path) + assert os.path.exists(os.path.join(save_path, "model.pth")) + assert os.path.exists(os.path.join(save_path, "config.json")) + + # Load the model + loaded_model = AIIAmoe.load(save_path) + + # Check if the loaded model is an instance of AIIAmoe + assert isinstance(loaded_model, AIIAmoe) + + # Clean up + os.remove(os.path.join(save_path, "model.pth")) + os.remove(os.path.join(save_path, "config.json")) + os.rmdir(save_path) + +def test_aiiachunked_creation(): + config = AIIAConfig() + model = AIIAchunked(config) + assert isinstance(model, AIIAchunked) + +def test_aiiachunked_save_load(): + config = AIIAConfig() + model = AIIAchunked(config) + save_path = "test_aiiachunked_save_load" + + # Save the model + model.save(save_path) + assert os.path.exists(os.path.join(save_path, "model.pth")) + assert os.path.exists(os.path.join(save_path, "config.json")) + + # Load the model + loaded_model = AIIAchunked.load(save_path) + + # Check if the loaded model is an instance of AIIAchunked + assert isinstance(loaded_model, AIIAchunked) + + # Clean up + os.remove(os.path.join(save_path, "model.pth")) + os.remove(os.path.join(save_path, "config.json")) + os.rmdir(save_path) \ No newline at end of file diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..5542a79 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,75 @@ +import os +import tempfile +import pytest +import torch.nn as nn +from aiia import AIIAConfig + +def test_aiia_config_initialization(): + config = AIIAConfig() + assert config.model_name == "AIIA" + assert config.kernel_size == 3 + assert config.activation_function == "GELU" + assert config.hidden_size == 512 + assert config.num_hidden_layers == 12 + assert config.num_channels == 3 + assert config.learning_rate == 5e-5 + +def test_aiia_config_custom_initialization(): + config = AIIAConfig( + model_name="CustomModel", + kernel_size=5, + activation_function="ReLU", + hidden_size=1024, + num_hidden_layers=8, + num_channels=1, + learning_rate=1e-4 + ) + assert config.model_name == "CustomModel" + assert config.kernel_size == 5 + assert config.activation_function == "ReLU" + assert config.hidden_size == 1024 + assert config.num_hidden_layers == 8 + assert config.num_channels == 1 + assert config.learning_rate == 1e-4 + +def test_aiia_config_invalid_activation_function(): + with pytest.raises(ValueError): + AIIAConfig(activation_function="InvalidFunction") + +def test_aiia_config_to_dict(): + config = AIIAConfig() + config_dict = config.to_dict() + assert isinstance(config_dict, dict) + assert config_dict["model_name"] == "AIIA" + assert config_dict["kernel_size"] == 3 + +def test_aiia_config_save_and_load(): + with tempfile.TemporaryDirectory() as tmpdir: + config = AIIAConfig(model_name="TempModel") + save_path = os.path.join(tmpdir, "config") + config.save(save_path) + + loaded_config = AIIAConfig.load(save_path) + assert loaded_config.model_name == "TempModel" + assert loaded_config.kernel_size == 3 + assert loaded_config.activation_function == "GELU" + +def test_aiia_config_save_and_load_with_custom_attributes(): + with tempfile.TemporaryDirectory() as tmpdir: + config = AIIAConfig(model_name="TempModel", custom_attr="value") + save_path = os.path.join(tmpdir, "config") + config.save(save_path) + + loaded_config = AIIAConfig.load(save_path) + assert loaded_config.model_name == "TempModel" + assert loaded_config.custom_attr == "value" + +def test_aiia_config_save_and_load_with_nested_attributes(): + with tempfile.TemporaryDirectory() as tmpdir: + config = AIIAConfig(model_name="TempModel", nested={"key": "value"}) + save_path = os.path.join(tmpdir, "config") + config.save(save_path) + + loaded_config = AIIAConfig.load(save_path) + assert loaded_config.model_name == "TempModel" + assert loaded_config.nested == {"key": "value"}