import os
import tempfile

from aiia import AIIABase, AIIAConfig
from aiunn import aiuNN, aiuNNConfig


def _assert_save_load_roundtrip(**load_kwargs):
    """Save a freshly built ``aiuNN`` to disk, load it back, and compare.

    The model is written under a temporary directory that is removed when
    the check finishes. The check passes when the loaded object is an
    ``aiuNN`` instance whose ``config.__dict__`` equals that of the saved
    model.

    Args:
        **load_kwargs: Extra keyword arguments forwarded unchanged to
            ``aiuNN.load`` (none for a plain load).
    """
    with tempfile.TemporaryDirectory() as workdir:
        # Build the base model and wrap it in the upsampler under test.
        base_config = AIIAConfig()
        upsampler_config = aiuNNConfig()
        backbone = AIIABase(base_config)
        original = aiuNN(backbone, config=upsampler_config)

        # Round-trip the model through the temporary directory.
        target = os.path.join(workdir, "model")
        original.save(target)
        restored = aiuNN.load(target, **load_kwargs)

        # Only the type and the configuration are compared here.
        assert isinstance(restored, aiuNN)
        assert restored.config.__dict__ == original.config.__dict__


def test_save_and_load_model():
    """Round-trip an ``aiuNN`` through save/load with default load options."""
    _assert_save_load_roundtrip()


def test_save_and_load_model_with_precision():
    """Round-trip an ``aiuNN`` through save/load using ``precision="bf16"``."""
    _assert_save_load_roundtrip(precision="bf16")


if __name__ == "__main__":
    # Allow running the checks directly, in the same order as defined above.
    for run_check in (
        test_save_and_load_model,
        test_save_and_load_model_with_precision,
    ):
        run_check()