feat/tf_support #37

Merged
Fabel merged 13 commits from feat/tf_support into develop 2025-04-16 20:59:48 +00:00
1 changed files with 9 additions and 7 deletions
Showing only changes of commit 023ca07cf7 - Show all commits

View File

@ -1,9 +1,9 @@
import os
import tempfile
import pytest
import torch.nn as nn
from aiia import AIIAConfig
def test_aiia_config_initialization():
config = AIIAConfig()
assert config.model_type == "AIIA"
@ -12,7 +12,7 @@ def test_aiia_config_initialization():
assert config.hidden_size == 512
assert config.num_hidden_layers == 12
assert config.num_channels == 3
assert config.learning_rate == 5e-5
def test_aiia_config_custom_initialization():
config = AIIAConfig(
@ -21,8 +21,7 @@ def test_aiia_config_custom_initialization():
activation_function="ReLU",
hidden_size=1024,
num_hidden_layers=8,
num_channels=1,
learning_rate=1e-4
num_channels=1
)
assert config.model_type == "CustomModel"
assert config.kernel_size == 5
@ -30,19 +29,20 @@ def test_aiia_config_custom_initialization():
assert config.hidden_size == 1024
assert config.num_hidden_layers == 8
assert config.num_channels == 1
assert config.learning_rate == 1e-4
def test_aiia_config_invalid_activation_function():
with pytest.raises(ValueError):
AIIAConfig(activation_function="InvalidFunction")
def test_aiia_config_to_dict():
config = AIIAConfig()
config_dict = config.to_dict()
assert isinstance(config_dict, dict)
assert config_dict["model_type"] == "AIIA"
assert config_dict["kernel_size"] == 3
def test_aiia_config_save_pretrained_and_from_pretrained():
with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_type="TempModel")
@ -54,6 +54,7 @@ def test_aiia_config_save_pretrained_and_from_pretrained():
assert loaded_config.kernel_size == 3
assert loaded_config.activation_function == "GELU"
def test_aiia_config_save_pretrained_and_load_with_custom_attributes():
with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_type="TempModel", custom_attr="value")
@ -64,6 +65,7 @@ def test_aiia_config_save_pretrained_and_load_with_custom_attributes():
assert loaded_config.model_type == "TempModel"
assert loaded_config.custom_attr == "value"
def test_aiia_config_save_pretrained_and_load_with_nested_attributes():
with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_type="TempModel", nested={"key": "value"})