# aiuNN/tests/upsampler/test_aiunn.py
# Tests for saving and reloading the aiuNN upsampler.

import os
import tempfile
from aiia import AIIABase, AIIAConfig
from aiunn import aiuNN, aiuNNConfig
def test_save_and_load_model():
    """Round-trip an aiuNN upsampler through save_pretrained/from_pretrained.

    The test builds an ``AIIABase`` from a default ``AIIAConfig`` and attaches
    it, via ``load_base_model``, to an ``aiuNN`` created from a default
    ``aiuNNConfig``. It then saves the upsampler into a temporary directory and
    reloads it with ``aiuNN.from_pretrained``. Finally it checks that the
    reloaded object is an ``aiuNN`` whose config matches the original on
    ``hidden_size``, ``_activation_function`` and ``architectures``.

    NOTE(review): only these three config fields are compared; the model
    weights are not checked by this test.
    """
    # TemporaryDirectory deletes the directory and its contents when the
    # ``with`` block exits, so the checkpoint never outlives the test.
    with tempfile.TemporaryDirectory() as scratch_dir:
        base_config = AIIAConfig()
        upsampler_config = aiuNNConfig()
        backbone = AIIABase(base_config)

        original = aiuNN(config=upsampler_config)
        original.load_base_model(backbone)

        checkpoint_dir = os.path.join(scratch_dir, "model")
        original.save_pretrained(checkpoint_dir)

        # Reload while the temporary directory still exists.
        restored = aiuNN.from_pretrained(checkpoint_dir)

        assert isinstance(restored, aiuNN)

        restored_cfg = restored.config
        original_cfg = original.config
        assert restored_cfg.hidden_size == original_cfg.hidden_size
        assert restored_cfg._activation_function == original_cfg._activation_function
        assert restored_cfg.architectures == original_cfg.architectures


if __name__ == "__main__":
    # Lets the test be run directly (``python test_aiunn.py``) without pytest;
    # a failing check surfaces as an uncaught AssertionError traceback.
    test_save_and_load_model()