31 lines
1.1 KiB
Python
31 lines
1.1 KiB
Python
import os
|
|
import tempfile
|
|
from aiia import AIIABase, AIIAConfig
|
|
from aiunn import aiuNN, aiuNNConfig
|
|
|
|
def test_save_and_load_model():
    """Round-trip an aiuNN model through save_pretrained / from_pretrained.

    Builds an ``aiuNN`` upsampler on top of an ``AIIABase`` model, saves it
    into a temporary directory, loads it back, and checks that the loaded
    object is an ``aiuNN`` whose config matches the original on the fields
    compared below.
    """
    # The temporary directory (and everything saved in it) is removed when
    # the ``with`` block exits, so the test leaves nothing behind.
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Create configurations and build a base model
        config = AIIAConfig()
        ai_config = aiuNNConfig()
        base_model = AIIABase(config)
        upsampler = aiuNN(config=ai_config)
        upsampler.load_base_model(base_model)

        # Save the model
        save_path = os.path.join(tmpdirname, "model")
        upsampler.save_pretrained(save_path)

        # Load the model
        loaded_upsampler = aiuNN.from_pretrained(save_path)

        # Verify that the loaded model is the same as the original model
        assert isinstance(loaded_upsampler, aiuNN)
        assert loaded_upsampler.config.hidden_size == upsampler.config.hidden_size
        assert loaded_upsampler.config._activation_function == upsampler.config._activation_function
        assert loaded_upsampler.config.architectures == upsampler.config.architectures
|
|
|
|
|
|
# Allow running this test file directly as a script, without a test runner.
if __name__ == "__main__":
    test_save_and_load_model()
|