"""Construction and save/load round-trip tests for the aiia model classes."""

import os
import tempfile

import torch
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig


def _assert_save_load_roundtrip(model, model_cls):
    """Save *model*, check the expected files exist, and reload it via *model_cls*.

    Args:
        model: An already-constructed model instance exposing ``save(path)``.
        model_cls: The class whose ``load(path)`` classmethod is exercised; the
            reloaded object must be an instance of it.

    The model is written inside a temporary directory so nothing is left in the
    working directory, and cleanup happens even when an assertion or ``load``
    fails (and regardless of which extra files ``save`` may write).
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Point ``save`` at a not-yet-existing sub-directory, so it still has
        # to create its target directory itself, as in a real first save.
        save_path = os.path.join(tmp_dir, model_cls.__name__)

        # Save the model
        model.save(save_path)
        assert os.path.exists(os.path.join(save_path, "model.pth"))
        assert os.path.exists(os.path.join(save_path, "config.json"))

        # Load the model (must happen before the temporary directory is removed)
        loaded_model = model_cls.load(save_path)

        # Check the loaded model has the expected type
        assert isinstance(loaded_model, model_cls)


def test_aiiabase_creation():
    config = AIIAConfig()
    model = AIIABase(config)
    assert isinstance(model, AIIABase)


def test_aiiabase_save_load():
    config = AIIAConfig()
    model = AIIABase(config)
    _assert_save_load_roundtrip(model, AIIABase)


def test_aiiabase_shared_creation():
    config = AIIAConfig()
    model = AIIABaseShared(config)
    assert isinstance(model, AIIABaseShared)


def test_aiiabase_shared_save_load():
    config = AIIAConfig()
    model = AIIABaseShared(config)
    _assert_save_load_roundtrip(model, AIIABaseShared)


def test_aiiaexpert_creation():
    config = AIIAConfig()
    model = AIIAExpert(config)
    assert isinstance(model, AIIAExpert)


def test_aiiaexpert_save_load():
    config = AIIAConfig()
    model = AIIAExpert(config)
    _assert_save_load_roundtrip(model, AIIAExpert)


def test_aiiamoe_creation():
    config = AIIAConfig()
    model = AIIAmoe(config, num_experts=5)
    assert isinstance(model, AIIAmoe)


def test_aiiamoe_save_load():
    config = AIIAConfig()
    model = AIIAmoe(config, num_experts=5)
    _assert_save_load_roundtrip(model, AIIAmoe)


def test_aiiachunked_creation():
    config = AIIAConfig()
    model = AIIAchunked(config)
    assert isinstance(model, AIIAchunked)


def test_aiiachunked_save_load():
    config = AIIAConfig()
    model = AIIAchunked(config)
    _assert_save_load_roundtrip(model, AIIAchunked)