Compare commits
2 Commits
217a1f85d7
...
4e6a02cdc9
Author | SHA1 | Date |
---|---|---|
|
4e6a02cdc9 | |
|
9d89c6c534 |
|
@ -10,7 +10,7 @@ include = '\.pyi?$'
|
|||
|
||||
[project]
|
||||
name = "aiia"
|
||||
version = "0.1.6"
|
||||
version = "0.2.0"
|
||||
description = "AIIA Deep Learning Model Implementation"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[metadata]
|
||||
name = aiia
|
||||
version = 0.1.6
|
||||
version = 0.2.0
|
||||
author = Falko Habel
|
||||
author_email = falko.habel@gmx.de
|
||||
description = AIIA deep learning model implementation
|
||||
|
|
|
@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
|
|||
from .pretrain.pretrainer import Pretrainer, ProjectionHead
|
||||
|
||||
|
||||
__version__ = "0.1.6"
|
||||
__version__ = "0.2.0"
|
||||
|
|
|
@ -1,133 +0,0 @@
|
|||
import os
|
||||
import torch
|
||||
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig
|
||||
|
||||
def test_aiiabase_creation():
|
||||
config = AIIAConfig()
|
||||
model = AIIABase(config)
|
||||
assert isinstance(model, AIIABase)
|
||||
|
||||
def test_aiiabase_save_load():
|
||||
config = AIIAConfig()
|
||||
model = AIIABase(config)
|
||||
save_path = "test_aiiabase_save_load"
|
||||
|
||||
# Save the model
|
||||
model.save(save_path)
|
||||
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
||||
assert os.path.exists(os.path.join(save_path, "config.json"))
|
||||
|
||||
# Load the model
|
||||
loaded_model = AIIABase.load(save_path)
|
||||
|
||||
# Check if the loaded model is an instance of AIIABase
|
||||
assert isinstance(loaded_model, AIIABase)
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(save_path, "model.pth"))
|
||||
os.remove(os.path.join(save_path, "config.json"))
|
||||
os.rmdir(save_path)
|
||||
|
||||
def test_aiiabase_shared_creation():
|
||||
config = AIIAConfig()
|
||||
model = AIIABaseShared(config)
|
||||
assert isinstance(model, AIIABaseShared)
|
||||
|
||||
def test_aiiabase_shared_save_load():
|
||||
config = AIIAConfig()
|
||||
model = AIIABaseShared(config)
|
||||
save_path = "test_aiiabase_shared_save_load"
|
||||
|
||||
# Save the model
|
||||
model.save(save_path)
|
||||
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
||||
assert os.path.exists(os.path.join(save_path, "config.json"))
|
||||
|
||||
# Load the model
|
||||
loaded_model = AIIABaseShared.load(save_path)
|
||||
|
||||
# Check if the loaded model is an instance of AIIABaseShared
|
||||
assert isinstance(loaded_model, AIIABaseShared)
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(save_path, "model.pth"))
|
||||
os.remove(os.path.join(save_path, "config.json"))
|
||||
os.rmdir(save_path)
|
||||
|
||||
def test_aiiaexpert_creation():
|
||||
config = AIIAConfig()
|
||||
model = AIIAExpert(config)
|
||||
assert isinstance(model, AIIAExpert)
|
||||
|
||||
def test_aiiaexpert_save_load():
|
||||
config = AIIAConfig()
|
||||
model = AIIAExpert(config)
|
||||
save_path = "test_aiiaexpert_save_load"
|
||||
|
||||
# Save the model
|
||||
model.save(save_path)
|
||||
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
||||
assert os.path.exists(os.path.join(save_path, "config.json"))
|
||||
|
||||
# Load the model
|
||||
loaded_model = AIIAExpert.load(save_path)
|
||||
|
||||
# Check if the loaded model is an instance of AIIAExpert
|
||||
assert isinstance(loaded_model, AIIAExpert)
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(save_path, "model.pth"))
|
||||
os.remove(os.path.join(save_path, "config.json"))
|
||||
os.rmdir(save_path)
|
||||
|
||||
def test_aiiamoe_creation():
|
||||
config = AIIAConfig()
|
||||
model = AIIAmoe(config, num_experts=5)
|
||||
assert isinstance(model, AIIAmoe)
|
||||
|
||||
def test_aiiamoe_save_load():
|
||||
config = AIIAConfig()
|
||||
model = AIIAmoe(config, num_experts=5)
|
||||
save_path = "test_aiiamoe_save_load"
|
||||
|
||||
# Save the model
|
||||
model.save(save_path)
|
||||
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
||||
assert os.path.exists(os.path.join(save_path, "config.json"))
|
||||
|
||||
# Load the model
|
||||
loaded_model = AIIAmoe.load(save_path)
|
||||
|
||||
# Check if the loaded model is an instance of AIIAmoe
|
||||
assert isinstance(loaded_model, AIIAmoe)
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(save_path, "model.pth"))
|
||||
os.remove(os.path.join(save_path, "config.json"))
|
||||
os.rmdir(save_path)
|
||||
|
||||
def test_aiiachunked_creation():
|
||||
config = AIIAConfig()
|
||||
model = AIIAchunked(config)
|
||||
assert isinstance(model, AIIAchunked)
|
||||
|
||||
def test_aiiachunked_save_load():
|
||||
config = AIIAConfig()
|
||||
model = AIIAchunked(config)
|
||||
save_path = "test_aiiachunked_save_load"
|
||||
|
||||
# Save the model
|
||||
model.save(save_path)
|
||||
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
||||
assert os.path.exists(os.path.join(save_path, "config.json"))
|
||||
|
||||
# Load the model
|
||||
loaded_model = AIIAchunked.load(save_path)
|
||||
|
||||
# Check if the loaded model is an instance of AIIAchunked
|
||||
assert isinstance(loaded_model, AIIAchunked)
|
||||
|
||||
# Clean up
|
||||
os.remove(os.path.join(save_path, "model.pth"))
|
||||
os.remove(os.path.join(save_path, "config.json"))
|
||||
os.rmdir(save_path)
|
|
@ -1,75 +0,0 @@
|
|||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
from aiia import AIIAConfig
|
||||
|
||||
def test_aiia_config_initialization():
|
||||
config = AIIAConfig()
|
||||
assert config.model_name == "AIIA"
|
||||
assert config.kernel_size == 3
|
||||
assert config.activation_function == "GELU"
|
||||
assert config.hidden_size == 512
|
||||
assert config.num_hidden_layers == 12
|
||||
assert config.num_channels == 3
|
||||
assert config.learning_rate == 5e-5
|
||||
|
||||
def test_aiia_config_custom_initialization():
|
||||
config = AIIAConfig(
|
||||
model_name="CustomModel",
|
||||
kernel_size=5,
|
||||
activation_function="ReLU",
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=8,
|
||||
num_channels=1,
|
||||
learning_rate=1e-4
|
||||
)
|
||||
assert config.model_name == "CustomModel"
|
||||
assert config.kernel_size == 5
|
||||
assert config.activation_function == "ReLU"
|
||||
assert config.hidden_size == 1024
|
||||
assert config.num_hidden_layers == 8
|
||||
assert config.num_channels == 1
|
||||
assert config.learning_rate == 1e-4
|
||||
|
||||
def test_aiia_config_invalid_activation_function():
|
||||
with pytest.raises(ValueError):
|
||||
AIIAConfig(activation_function="InvalidFunction")
|
||||
|
||||
def test_aiia_config_to_dict():
|
||||
config = AIIAConfig()
|
||||
config_dict = config.to_dict()
|
||||
assert isinstance(config_dict, dict)
|
||||
assert config_dict["model_name"] == "AIIA"
|
||||
assert config_dict["kernel_size"] == 3
|
||||
|
||||
def test_aiia_config_save_and_load():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = AIIAConfig(model_name="TempModel")
|
||||
save_path = os.path.join(tmpdir, "config")
|
||||
config.save(save_path)
|
||||
|
||||
loaded_config = AIIAConfig.load(save_path)
|
||||
assert loaded_config.model_name == "TempModel"
|
||||
assert loaded_config.kernel_size == 3
|
||||
assert loaded_config.activation_function == "GELU"
|
||||
|
||||
def test_aiia_config_save_and_load_with_custom_attributes():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = AIIAConfig(model_name="TempModel", custom_attr="value")
|
||||
save_path = os.path.join(tmpdir, "config")
|
||||
config.save(save_path)
|
||||
|
||||
loaded_config = AIIAConfig.load(save_path)
|
||||
assert loaded_config.model_name == "TempModel"
|
||||
assert loaded_config.custom_attr == "value"
|
||||
|
||||
def test_aiia_config_save_and_load_with_nested_attributes():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = AIIAConfig(model_name="TempModel", nested={"key": "value"})
|
||||
save_path = os.path.join(tmpdir, "config")
|
||||
config.save(save_path)
|
||||
|
||||
loaded_config = AIIAConfig.load(save_path)
|
||||
assert loaded_config.model_name == "TempModel"
|
||||
assert loaded_config.nested == {"key": "value"}
|
Loading…
Reference in New Issue