diff --git a/example.py b/example.py index 8dbb67e..2ce0b6a 100644 --- a/example.py +++ b/example.py @@ -1,10 +1,12 @@ -from aiia.model import AIIABase -from aiia.model import AIIAConfig -from aiia.pretrain import Pretrainer +from src.aiia.model import AIIAmoe +from src.aiia.model import AIIAConfig +from src.aiia.pretrain import Pretrainer # Create your model -config = AIIAConfig(model_name="AIIA-Base-512x20k") -model = AIIABase(config) +config = AIIAConfig(num_experts=5) +model = AIIAmoe(config) +model.save_pretrained("test") +model = AIIAmoe.from_pretrained("test") # Initialize pretrainer with the model pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)