48 lines
1.7 KiB
Python
48 lines
1.7 KiB
Python
import os
|
|
import tempfile
|
|
from aiia import AIIABase, AIIAConfig
|
|
from aiunn import aiuNN, aiuNNConfig
|
|
|
|
def test_save_and_load_model():
    """Round-trip an aiuNN upsampler through ``save`` and ``aiuNN.load``.

    A fresh upsampler is built on top of a default ``AIIABase`` backbone,
    saved under a path inside a temporary directory, and loaded back with
    default arguments. The test checks only that the reloaded object is an
    ``aiuNN`` and that its config attributes equal the original's; model
    weights are not compared.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Backbone config first, then the upsampler config, then the models.
        backbone_cfg = AIIAConfig()
        upsampler_cfg = aiuNNConfig()
        original = aiuNN(AIIABase(backbone_cfg), config=upsampler_cfg)

        model_path = os.path.join(tmp_dir, "model")
        original.save(model_path)

        restored = aiuNN.load(model_path)

        # Type check plus a shallow comparison of the config attribute dicts.
        assert isinstance(restored, aiuNN)
        restored_cfg = restored.config.__dict__
        original_cfg = original.config.__dict__
        assert restored_cfg == original_cfg
|
|
|
|
def test_save_and_load_model_with_precision():
    """Round-trip an aiuNN upsampler, reloading it with ``precision="bf16"``.

    Same setup as the default round-trip: build an upsampler over a default
    ``AIIABase`` backbone, save it under a temporary path, then load it back
    while passing ``precision="bf16"`` to ``aiuNN.load``. Only the reloaded
    object's type and config attributes are checked.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Backbone config first, then the upsampler config, then the models.
        backbone_cfg = AIIAConfig()
        upsampler_cfg = aiuNNConfig()
        original = aiuNN(AIIABase(backbone_cfg), config=upsampler_cfg)

        model_path = os.path.join(tmp_dir, "model")
        original.save(model_path)

        restored = aiuNN.load(model_path, precision="bf16")

        # NOTE(review): nothing here verifies that bf16 was actually applied
        # to the loaded model (e.g. parameter dtype); consider adding such an
        # assertion once the aiuNN API for it is confirmed.
        assert isinstance(restored, aiuNN)
        restored_cfg = restored.config.__dict__
        original_cfg = original.config.__dict__
        assert restored_cfg == original_cfg
|
|
|
|
if __name__ == "__main__":
    # Allow running the tests directly (without pytest), in definition order.
    for run_test in (
        test_save_and_load_model,
        test_save_and_load_model_with_precision,
    ):
        run_test()