import os
import tempfile

import torch
from aiia.model.Model import (
    AIIABase,
    AIIABaseShared,
    AIIAExpert,
    AIIAmoe,
    AIIAConfig,
    AIIASparseMoe,
)


def _check_save_load_roundtrip(model, model_cls, dirname):
    """Save *model* with ``save_pretrained`` and reload it with ``model_cls.from_pretrained``.

    The checkpoint is written into a fresh sub-directory of a temporary
    directory. Nothing is left in the current working directory, and the
    files are removed even when one of the assertions fails.

    Args:
        model: Model instance to save.
        model_cls: Class whose ``from_pretrained`` reloads the checkpoint.
            The reloaded object must be an instance of this class.
        dirname: Name of the checkpoint directory to create inside the
            temporary directory. It does not exist before ``save_pretrained``
            is called.
    """
    with tempfile.TemporaryDirectory() as tmp_root:
        # Pass a path that does not exist yet, as the original tests did, so
        # save_pretrained is responsible for creating the directory itself.
        save_path = os.path.join(tmp_root, dirname)

        # Save the model and check that the expected artifacts were written.
        model.save_pretrained(save_path)
        assert os.path.exists(os.path.join(save_path, "model.safetensors"))
        assert os.path.exists(os.path.join(save_path, "config.json"))

        # Load the model back and check that it has the requested type.
        loaded_model = model_cls.from_pretrained(save_path)
        assert isinstance(loaded_model, model_cls)


def test_aiiabase_creation():
    config = AIIAConfig()
    model = AIIABase(config)
    assert isinstance(model, AIIABase)


def test_aiiabase_save_pretrained_from_pretrained():
    config = AIIAConfig()
    model = AIIABase(config)
    _check_save_load_roundtrip(
        model, AIIABase, "test_aiiabase_save_pretrained_load"
    )


def test_aiiabase_shared_creation():
    config = AIIAConfig()
    model = AIIABaseShared(config)
    assert isinstance(model, AIIABaseShared)


def test_aiiabase_shared_save_pretrained_from_pretrained():
    config = AIIAConfig()
    model = AIIABaseShared(config)
    _check_save_load_roundtrip(
        model, AIIABaseShared, "test_aiiabase_shared_save_pretrained_load"
    )


def test_aiiaexpert_creation():
    config = AIIAConfig()
    model = AIIAExpert(config)
    assert isinstance(model, AIIAExpert)


def test_aiiaexpert_save_pretrained_from_pretrained():
    config = AIIAConfig()
    model = AIIAExpert(config)
    _check_save_load_roundtrip(
        model, AIIAExpert, "test_aiiaexpert_save_pretrained_load"
    )


def test_aiiamoe_creation():
    config = AIIAConfig(num_experts=3)
    model = AIIAmoe(config)
    assert isinstance(model, AIIAmoe)


def test_aiiamoe_save_pretrained_from_pretrained():
    config = AIIAConfig(num_experts=3)
    model = AIIAmoe(config)
    _check_save_load_roundtrip(
        model, AIIAmoe, "test_aiiamoe_save_pretrained_load"
    )


def test_aiiasparsemoe_creation():
    config = AIIAConfig(num_experts=5, top_k=2)
    model = AIIASparseMoe(config, base_class=AIIABaseShared)
    assert isinstance(model, AIIASparseMoe)


def test_aiiasparsemoe_save_pretrained_from_pretrained():
    config = AIIAConfig(num_experts=3, top_k=1)
    model = AIIASparseMoe(config)
    _check_save_load_roundtrip(
        model, AIIASparseMoe, "test_aiiasparsemoe_save_pretrained_load"
    )