feat/tf_support #37
|
@ -1,9 +1,9 @@
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import pytest
|
import pytest
|
||||||
import torch.nn as nn
|
|
||||||
from aiia import AIIAConfig
|
from aiia import AIIAConfig
|
||||||
|
|
||||||
|
|
||||||
def test_aiia_config_initialization():
|
def test_aiia_config_initialization():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
assert config.model_type == "AIIA"
|
assert config.model_type == "AIIA"
|
||||||
|
@ -12,7 +12,7 @@ def test_aiia_config_initialization():
|
||||||
assert config.hidden_size == 512
|
assert config.hidden_size == 512
|
||||||
assert config.num_hidden_layers == 12
|
assert config.num_hidden_layers == 12
|
||||||
assert config.num_channels == 3
|
assert config.num_channels == 3
|
||||||
assert config.learning_rate == 5e-5
|
|
||||||
|
|
||||||
def test_aiia_config_custom_initialization():
|
def test_aiia_config_custom_initialization():
|
||||||
config = AIIAConfig(
|
config = AIIAConfig(
|
||||||
|
@ -21,8 +21,7 @@ def test_aiia_config_custom_initialization():
|
||||||
activation_function="ReLU",
|
activation_function="ReLU",
|
||||||
hidden_size=1024,
|
hidden_size=1024,
|
||||||
num_hidden_layers=8,
|
num_hidden_layers=8,
|
||||||
num_channels=1,
|
num_channels=1
|
||||||
learning_rate=1e-4
|
|
||||||
)
|
)
|
||||||
assert config.model_type == "CustomModel"
|
assert config.model_type == "CustomModel"
|
||||||
assert config.kernel_size == 5
|
assert config.kernel_size == 5
|
||||||
|
@ -30,19 +29,20 @@ def test_aiia_config_custom_initialization():
|
||||||
assert config.hidden_size == 1024
|
assert config.hidden_size == 1024
|
||||||
assert config.num_hidden_layers == 8
|
assert config.num_hidden_layers == 8
|
||||||
assert config.num_channels == 1
|
assert config.num_channels == 1
|
||||||
assert config.learning_rate == 1e-4
|
|
||||||
|
|
||||||
def test_aiia_config_invalid_activation_function():
|
def test_aiia_config_invalid_activation_function():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
AIIAConfig(activation_function="InvalidFunction")
|
AIIAConfig(activation_function="InvalidFunction")
|
||||||
|
|
||||||
|
|
||||||
def test_aiia_config_to_dict():
|
def test_aiia_config_to_dict():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
config_dict = config.to_dict()
|
config_dict = config.to_dict()
|
||||||
assert isinstance(config_dict, dict)
|
assert isinstance(config_dict, dict)
|
||||||
assert config_dict["model_type"] == "AIIA"
|
|
||||||
assert config_dict["kernel_size"] == 3
|
assert config_dict["kernel_size"] == 3
|
||||||
|
|
||||||
|
|
||||||
def test_aiia_config_save_pretrained_and_from_pretrained():
|
def test_aiia_config_save_pretrained_and_from_pretrained():
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
config = AIIAConfig(model_type="TempModel")
|
config = AIIAConfig(model_type="TempModel")
|
||||||
|
@ -54,6 +54,7 @@ def test_aiia_config_save_pretrained_and_from_pretrained():
|
||||||
assert loaded_config.kernel_size == 3
|
assert loaded_config.kernel_size == 3
|
||||||
assert loaded_config.activation_function == "GELU"
|
assert loaded_config.activation_function == "GELU"
|
||||||
|
|
||||||
|
|
||||||
def test_aiia_config_save_pretrained_and_load_with_custom_attributes():
|
def test_aiia_config_save_pretrained_and_load_with_custom_attributes():
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
config = AIIAConfig(model_type="TempModel", custom_attr="value")
|
config = AIIAConfig(model_type="TempModel", custom_attr="value")
|
||||||
|
@ -64,6 +65,7 @@ def test_aiia_config_save_pretrained_and_load_with_custom_attributes():
|
||||||
assert loaded_config.model_type == "TempModel"
|
assert loaded_config.model_type == "TempModel"
|
||||||
assert loaded_config.custom_attr == "value"
|
assert loaded_config.custom_attr == "value"
|
||||||
|
|
||||||
|
|
||||||
def test_aiia_config_save_pretrained_and_load_with_nested_attributes():
|
def test_aiia_config_save_pretrained_and_load_with_nested_attributes():
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
config = AIIAConfig(model_type="TempModel", nested={"key": "value"})
|
config = AIIAConfig(model_type="TempModel", nested={"key": "value"})
|
||||||
|
|
Loading…
Reference in New Issue