feat/tf_support #37

Merged
Fabel merged 13 commits from feat/tf_support into develop 2025-04-16 20:59:48 +00:00
3 changed files with 93 additions and 119 deletions
Showing only changes of commit d4db1ef116 - Show all commits

View File

@ -1,159 +1,133 @@
import os
import torch
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig, AIIASparseMoe
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAConfig, AIIASparseMoe
def test_aiiabase_creation():
config = AIIAConfig()
model = AIIABase(config)
assert isinstance(model, AIIABase)
def test_aiiabase_save_load():
def test_aiiabase_save_pretrained_from_pretrained():
config = AIIAConfig()
model = AIIABase(config)
save_path = "test_aiiabase_save_load"
save_pretrained_path = "test_aiiabase_save_pretrained_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# save_pretrained the model
model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
# Load the model
loaded_model = AIIABase.load(save_path)
loaded_model = AIIABase.from_pretrained(save_pretrained_path)
# Check if the loaded model is an instance of AIIABase
assert isinstance(loaded_model, AIIABase)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
os.remove(os.path.join(save_pretrained_path, "config.json"))
os.rmdir(save_pretrained_path)
def test_aiiabase_shared_creation():
config = AIIAConfig()
model = AIIABaseShared(config)
assert isinstance(model, AIIABaseShared)
def test_aiiabase_shared_save_load():
def test_aiiabase_shared_save_pretrained_from_pretrained():
config = AIIAConfig()
model = AIIABaseShared(config)
save_path = "test_aiiabase_shared_save_load"
save_pretrained_path = "test_aiiabase_shared_save_pretrained_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# save_pretrained the model
model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
# Load the model
loaded_model = AIIABaseShared.load(save_path)
loaded_model = AIIABaseShared.from_pretrained(save_pretrained_path)
# Check if the loaded model is an instance of AIIABaseShared
assert isinstance(loaded_model, AIIABaseShared)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
os.remove(os.path.join(save_pretrained_path, "config.json"))
os.rmdir(save_pretrained_path)
def test_aiiaexpert_creation():
config = AIIAConfig()
model = AIIAExpert(config)
assert isinstance(model, AIIAExpert)
def test_aiiaexpert_save_load():
def test_aiiaexpert_save_pretrained_from_pretrained():
config = AIIAConfig()
model = AIIAExpert(config)
save_path = "test_aiiaexpert_save_load"
save_pretrained_path = "test_aiiaexpert_save_pretrained_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# save_pretrained the model
model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
# Load the model
loaded_model = AIIAExpert.load(save_path)
loaded_model = AIIAExpert.from_pretrained(save_pretrained_path)
# Check if the loaded model is an instance of AIIAExpert
assert isinstance(loaded_model, AIIAExpert)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
os.remove(os.path.join(save_pretrained_path, "config.json"))
os.rmdir(save_pretrained_path)
def test_aiiamoe_creation():
config = AIIAConfig()
model = AIIAmoe(config, num_experts=5)
assert isinstance(model, AIIAmoe)
def test_aiiamoe_save_load():
def test_aiiamoe_save_pretrained_from_pretrained():
config = AIIAConfig()
model = AIIAmoe(config, num_experts=5)
save_path = "test_aiiamoe_save_load"
save_pretrained_path = "test_aiiamoe_save_pretrained_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# save_pretrained the model
model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
# Load the model
loaded_model = AIIAmoe.load(save_path)
loaded_model = AIIAmoe.from_pretrained(save_pretrained_path)
# Check if the loaded model is an instance of AIIAmoe
assert isinstance(loaded_model, AIIAmoe)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
os.remove(os.path.join(save_pretrained_path, "config.json"))
os.rmdir(save_pretrained_path)
def test_aiiasparsemoe_creation():
config = AIIAConfig()
model = AIIASparseMoe(config, num_experts=5, top_k=2)
assert isinstance(model, AIIASparseMoe)
def test_aiiasparsemoe_save_load():
def test_aiiasparsemoe_save_pretrained_from_pretrained():
config = AIIAConfig()
model = AIIASparseMoe(config, num_experts=3, top_k=1)
save_path = "test_aiiasparsemoe_save_load"
save_pretrained_path = "test_aiiasparsemoe_save_pretrained_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# save_pretrained the model
model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
# Load the model
loaded_model = AIIASparseMoe.load(save_path)
loaded_model = AIIASparseMoe.from_pretrained(save_pretrained_path)
# Check if the loaded model is an instance of AIIASparseMoe
assert isinstance(loaded_model, AIIASparseMoe)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
def test_aiiachunked_creation():
config = AIIAConfig()
model = AIIAchunked(config)
assert isinstance(model, AIIAchunked)
def test_aiiachunked_save_load():
config = AIIAConfig()
model = AIIAchunked(config)
save_path = "test_aiiachunked_save_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# Load the model
loaded_model = AIIAchunked.load(save_path)
# Check if the loaded model is an instance of AIIAchunked
assert isinstance(loaded_model, AIIAchunked)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
os.remove(os.path.join(save_pretrained_path, "config.json"))
os.rmdir(save_pretrained_path)

View File

@ -6,7 +6,7 @@ from aiia import AIIAConfig
def test_aiia_config_initialization():
config = AIIAConfig()
assert config.model_name == "AIIA"
assert config.model_type == "AIIA"
assert config.kernel_size == 3
assert config.activation_function == "GELU"
assert config.hidden_size == 512
@ -16,7 +16,7 @@ def test_aiia_config_initialization():
def test_aiia_config_custom_initialization():
config = AIIAConfig(
model_name="CustomModel",
model_type="CustomModel",
kernel_size=5,
activation_function="ReLU",
hidden_size=1024,
@ -24,7 +24,7 @@ def test_aiia_config_custom_initialization():
num_channels=1,
learning_rate=1e-4
)
assert config.model_name == "CustomModel"
assert config.model_type == "CustomModel"
assert config.kernel_size == 5
assert config.activation_function == "ReLU"
assert config.hidden_size == 1024
@ -40,36 +40,36 @@ def test_aiia_config_to_dict():
config = AIIAConfig()
config_dict = config.to_dict()
assert isinstance(config_dict, dict)
assert config_dict["model_name"] == "AIIA"
assert config_dict["model_type"] == "AIIA"
assert config_dict["kernel_size"] == 3
def test_aiia_config_save_and_load():
def test_aiia_config_save_pretrained_and_from_pretrained():
with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_name="TempModel")
save_path = os.path.join(tmpdir, "config")
config.save(save_path)
config = AIIAConfig(model_type="TempModel")
save_pretrained_path = os.path.join(tmpdir, "config")
config.save_pretrained(save_pretrained_path)
loaded_config = AIIAConfig.load(save_path)
assert loaded_config.model_name == "TempModel"
loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
assert loaded_config.model_type == "TempModel"
assert loaded_config.kernel_size == 3
assert loaded_config.activation_function == "GELU"
def test_aiia_config_save_and_load_with_custom_attributes():
def test_aiia_config_save_pretrained_and_load_with_custom_attributes():
with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_name="TempModel", custom_attr="value")
save_path = os.path.join(tmpdir, "config")
config.save(save_path)
config = AIIAConfig(model_type="TempModel", custom_attr="value")
save_pretrained_path = os.path.join(tmpdir, "config")
config.save_pretrained(save_pretrained_path)
loaded_config = AIIAConfig.load(save_path)
assert loaded_config.model_name == "TempModel"
loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
assert loaded_config.model_type == "TempModel"
assert loaded_config.custom_attr == "value"
def test_aiia_config_save_and_load_with_nested_attributes():
def test_aiia_config_save_pretrained_and_load_with_nested_attributes():
with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_name="TempModel", nested={"key": "value"})
save_path = os.path.join(tmpdir, "config")
config.save(save_path)
config = AIIAConfig(model_type="TempModel", nested={"key": "value"})
save_pretrained_path = os.path.join(tmpdir, "config")
config.save_pretrained(save_pretrained_path)
loaded_config = AIIAConfig.load(save_path)
assert loaded_config.model_name == "TempModel"
loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
assert loaded_config.model_type == "TempModel"
assert loaded_config.nested == {"key": "value"}

View File

@ -94,7 +94,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
# Execute training with patched methods
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \
patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \
patch.object(Pretrainer, 'save_losses') as mock_save_losses, \
patch.object(Pretrainer, 'save_pretrained_losses') as mock_save_pretrained_losses, \
patch('builtins.open', mock_open()):
pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2)
@ -104,10 +104,10 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
assert mock_process_batch.call_count == 2
assert mock_validate.call_count == 2
# Check for "Best model saved!" instead of model.save()
mock_print.assert_any_call("Best model saved!")
# Check for "Best model save_pretrainedd!" instead of model.save_pretrained()
mock_print.assert_any_call("Best model save_pretrainedd!")
mock_save_losses.assert_called_once()
mock_save_pretrained_losses.assert_called_once()
# Verify state changes
assert len(pretrainer.train_losses) == 2
@ -153,13 +153,13 @@ def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
with patch.object(Pretrainer, 'save_losses') as mock_save_losses:
with patch.object(Pretrainer, 'save_pretrained_losses') as mock_save_pretrained_losses:
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
# Verify empty loader behavior
assert len(pretrainer.train_losses) == 1
assert pretrainer.train_losses[0] == 0.0
mock_save_losses.assert_called_once()
mock_save_pretrained_losses.assert_called_once()
@patch('pandas.concat')
@patch('pandas.read_parquet')
@ -180,7 +180,7 @@ def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat)
pretrainer.optimizer = MagicMock()
with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
patch.object(Pretrainer, 'save_losses'):
patch.object(Pretrainer, 'save_pretrained_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
# Verify None batch handling
@ -212,7 +212,7 @@ def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_
custom_batch_size = 16
custom_sample_size = 5000
with patch.object(Pretrainer, 'save_losses'):
with patch.object(Pretrainer, 'save_pretrained_losses'):
pretrainer.train(
['path/to/dataset.parquet'],
output_path=custom_output_path,
@ -233,7 +233,7 @@ def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('builtins.print') # Add this to mock the print function
def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
"""Test that model is saved only when validation loss improves."""
"""Test that model is save_pretrainedd only when validation loss improves."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
@ -262,11 +262,11 @@ def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_re
# Test improving validation loss
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \
patch.object(Pretrainer, 'save_losses'):
patch.object(Pretrainer, 'save_pretrained_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
# Check for "Best model saved!" 3 times
assert mock_print.call_args_list.count(call("Best model saved!")) == 3
# Check for "Best model save_pretrainedd!" 3 times
assert mock_print.call_args_list.count(call("Best model save_pretrainedd!")) == 3
# Reset for next test
mock_print.reset_mock()
@ -278,11 +278,11 @@ def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_re
# Test fluctuating validation loss
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
patch.object(Pretrainer, 'save_losses'):
patch.object(Pretrainer, 'save_pretrained_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
# Should print "Best model saved!" only on first and third epochs
assert mock_print.call_args_list.count(call("Best model saved!")) == 2
# Should print "Best model save_pretrainedd!" only on first and third epochs
assert mock_print.call_args_list.count(call("Best model save_pretrainedd!")) == 2
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
@ -296,13 +296,13 @@ def test_validate(mock_process_batch):
loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate)
assert loss == 0.5
# Test the save_losses method
@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses')
def test_save_losses(mock_save_losses):
# Test the save_pretrained_losses method
@patch('aiia.pretrain.pretrainer.Pretrainer.save_pretrained_losses')
def test_save_pretrained_losses(mock_save_pretrained_losses):
pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
pretrainer.train_losses = [0.1, 0.2]
pretrainer.val_losses = [0.3, 0.4]
csv_file = 'losses.csv'
pretrainer.save_losses(csv_file)
mock_save_losses.assert_called_once_with(csv_file)
pretrainer.save_pretrained_losses(csv_file)
mock_save_pretrained_losses.assert_called_once_with(csv_file)