# File-listing metadata captured with this source: 45 lines, 1.5 KiB, Python.
import pytest
from aiunn import aiuNNConfig


def test_default_initialization():
    """A config built with no arguments exposes the default upsample settings.

    Expected defaults: scale 2, 'bilinear' mode, a falsy align_corners flag,
    and a layer list holding a single entry named 'Upsample'.
    """
    cfg = aiuNNConfig()

    # Scalar upsample parameters.
    assert cfg.upsample_scale == 2
    assert cfg.upsample_mode == 'bilinear'
    assert not cfg.upsample_align_corners

    # Exactly one layer, and it is the upsample layer.
    layers = cfg.layers
    assert len(layers) == 1
    assert layers[0]['name'] == 'Upsample'


def test_custom_initialization():
    """Explicit constructor arguments override the upsample defaults.

    A plain dict is passed as ``base_config``. This test checks only the
    upsample settings and the layer list, not how the dict's keys are used.
    """
    base = {'some_key': 'some_value'}
    cfg = aiuNNConfig(
        base_config=base,
        upsample_scale=3,
        upsample_mode='nearest',
        upsample_align_corners=True,
    )

    # Overridden scalar settings.
    assert cfg.upsample_scale == 3
    assert cfg.upsample_mode == 'nearest'
    assert cfg.upsample_align_corners

    # Overrides must not change the shape of the layer list.
    layers = cfg.layers
    assert len(layers) == 1
    assert layers[0]['name'] == 'Upsample'


def test_add_upsample_layer():
    """add_upsample_layer() on a default config keeps a single Upsample layer.

    A default config already contains one 'Upsample' layer, so the call must
    not append a second one. The remaining layer must still be the upsample
    layer; checking only the count would also pass if the method replaced it
    with a different layer.
    """
    config = aiuNNConfig()

    config.add_upsample_layer()

    assert len(config.layers) == 1
    assert config.layers[0]['name'] == 'Upsample'


def test_upsample_layer_not_duplicated():
    """add_upsample_layer() restores a removed Upsample layer exactly once.

    All 'Upsample' layers are stripped from a default config. Then
    add_upsample_layer() is called twice: the first call must re-create the
    layer and the second must not append a duplicate.
    """
    config = aiuNNConfig()
    initial_length = len(config.layers)

    # Remove all existing 'Upsample' layers. Iterate over a copy because
    # config.layers is mutated inside the loop.
    for layer in list(config.layers):
        if layer.get('name') == 'Upsample':
            config.layers.remove(layer)
    # Precondition: the removal actually took effect.
    assert not any(layer.get('name') == 'Upsample' for layer in config.layers)

    config.add_upsample_layer()
    config.add_upsample_layer()  # second call must not add another layer

    upsample_count = sum(
        1 for layer in config.layers if layer.get('name') == 'Upsample'
    )
    assert upsample_count == 1
    # The layer list is back to its original single-entry size.
    assert len(config.layers) == initial_length == 1


def test_base_config_with_to_dict():
    """Entries from a base_config object's to_dict() appear in the config.

    The stub below implements only to_dict(). The test checks that its
    'base_key' entry shows up in the resulting config's to_dict() output.
    """

    class MockBaseConfig:
        # Minimal stand-in: only the to_dict() hook is provided.
        def to_dict(self):
            return {'base_key': 'base_value'}

    config = aiuNNConfig(base_config=MockBaseConfig())
    serialized = config.to_dict()
    assert serialized['base_key'] == 'base_value'