tests with new config

This commit is contained in:
Falko Victor Habel 2025-04-13 22:19:15 +02:00
parent 3e78a595c9
commit 9ec01e2dd0
1 changed files with 15 additions and 15 deletions

View File

@ -12,7 +12,7 @@ def test_aiiabase_save_pretrained_from_pretrained():
model = AIIABase(config) model = AIIABase(config)
save_pretrained_path = "test_aiiabase_save_pretrained_load" save_pretrained_path = "test_aiiabase_save_pretrained_load"
# save_pretrained the model # Save the model
model.save_pretrained(save_pretrained_path) model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
@ -38,7 +38,7 @@ def test_aiiabase_shared_save_pretrained_from_pretrained():
model = AIIABaseShared(config) model = AIIABaseShared(config)
save_pretrained_path = "test_aiiabase_shared_save_pretrained_load" save_pretrained_path = "test_aiiabase_shared_save_pretrained_load"
# save_pretrained the model # Save the model
model.save_pretrained(save_pretrained_path) model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
@ -64,7 +64,7 @@ def test_aiiaexpert_save_pretrained_from_pretrained():
model = AIIAExpert(config) model = AIIAExpert(config)
save_pretrained_path = "test_aiiaexpert_save_pretrained_load" save_pretrained_path = "test_aiiaexpert_save_pretrained_load"
# save_pretrained the model # Save the model
model.save_pretrained(save_pretrained_path) model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
@ -81,16 +81,16 @@ def test_aiiaexpert_save_pretrained_from_pretrained():
os.rmdir(save_pretrained_path) os.rmdir(save_pretrained_path)
def test_aiiamoe_creation(): def test_aiiamoe_creation():
config = AIIAConfig() config = AIIAConfig(num_experts=3)
model = AIIAmoe(config, num_experts=5) model = AIIAmoe(config)
assert isinstance(model, AIIAmoe) assert isinstance(model, AIIAmoe)
def test_aiiamoe_save_pretrained_from_pretrained(): def test_aiiamoe_save_pretrained_from_pretrained():
config = AIIAConfig() config = AIIAConfig(num_experts=3)
model = AIIAmoe(config, num_experts=5) model = AIIAmoe(config)
save_pretrained_path = "test_aiiamoe_save_pretrained_load" save_pretrained_path = "test_aiiamoe_save_pretrained_load"
# save_pretrained the model # Save the model
model.save_pretrained(save_pretrained_path) model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
@ -107,16 +107,16 @@ def test_aiiamoe_save_pretrained_from_pretrained():
os.rmdir(save_pretrained_path) os.rmdir(save_pretrained_path)
def test_aiiasparsemoe_creation(): def test_aiiasparsemoe_creation():
config = AIIAConfig() config = AIIAConfig(num_experts=5, top_k=2)
model = AIIASparseMoe(config, num_experts=5, top_k=2) model = AIIASparseMoe(config, base_class=AIIABaseShared)
assert isinstance(model, AIIASparseMoe) assert isinstance(model, AIIASparseMoe)
def test_aiiasparsemoe_save_pretrained_from_pretrained(): def test_aiiasparsemoe_save_pretrained_from_pretrained():
config = AIIAConfig() config = AIIAConfig(num_experts=3, top_k=1)
model = AIIASparseMoe(config, num_experts=3, top_k=1) model = AIIASparseMoe(config)
save_pretrained_path = "test_aiiasparsemoe_save_pretrained_load" save_pretrained_path = "test_aiiasparsemoe_save_pretrained_load"
# save_pretrained the model # Save the model
model.save_pretrained(save_pretrained_path) model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))