import os
import tempfile

from aiia import AIIABase, AIIAConfig
from aiunn import aiuNN, aiuNNConfig


def test_save_and_load_model():
    """Round-trip an aiuNN model through ``save_pretrained``/``from_pretrained``.

    Builds an ``aiuNN`` upsampler on top of an ``AIIABase`` model, saves it
    into a temporary directory, loads it back, and asserts that the reloaded
    object is an ``aiuNN`` whose config matches the original on
    ``hidden_size``, ``_activation_function`` and ``architectures``.
    """
    # The temporary directory (and the saved model in it) is removed when the
    # ``with`` block exits, so every check has to happen inside it.
    with tempfile.TemporaryDirectory() as workdir:
        # Build the configurations, the base model, and the upsampler.
        base_cfg = AIIAConfig()
        upsampler_cfg = aiuNNConfig()
        backbone = AIIABase(base_cfg)
        original = aiuNN(config=upsampler_cfg)
        original.load_base_model(backbone)

        # Persist the upsampler, then restore it from the same location.
        target = os.path.join(workdir, "model")
        original.save_pretrained(target)
        restored = aiuNN.from_pretrained(target)

        # The restored object must have the right type ...
        assert isinstance(restored, aiuNN)

        # ... and agree with the original on each compared config field.
        compared_fields = ("hidden_size", "_activation_function", "architectures")
        for field_name in compared_fields:
            restored_value = getattr(restored.config, field_name)
            original_value = getattr(original.config, field_name)
            assert restored_value == original_value


if __name__ == "__main__":
    test_save_and_load_model()