From 0852ddb1097865f000daf8db3f538ee5f36332c1 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Fri, 11 Apr 2025 22:36:41 +0200
Subject: [PATCH 01/23] removed the chunked and recursive models, since they
 are now deprecated. For the transformer switch I decided to focus on my idea
 of a Sparse Mixture of Experts with shared params.

---
 src/aiia/model/Model.py | 57 -----------------------------------------
 1 file changed, 57 deletions(-)

diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py
index f21a0db..abcc34a 100644
--- a/src/aiia/model/Model.py
+++ b/src/aiia/model/Model.py
@@ -294,63 +294,6 @@ class AIIASparseMoe(AIIAmoe):
 
         return torch.cat(merged_outputs, dim=0)
 
-
-class AIIAchunked(AIIA):
-    def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
-
-        # Update config with new parameters if provided
-        self.config.patch_size = patch_size
-
-        # Initialize base CNN for processing each patch using the specified base class
-        if issubclass(base_class, AIIABase):
-            self.base_cnn = AIIABase(self.config, **kwargs)
-        elif issubclass(base_class, AIIABaseShared):  # Add support for AIIABaseShared
-            self.base_cnn = AIIABaseShared(self.config, **kwargs)
-        else:
-            raise ValueError("Invalid base class")
-
-    def forward(self, x):
-        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
-        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, self.patch_size, self.patch_size)
-        patch_outputs = []
-
-        for p in torch.split(patches, 1, dim=2):
-            p = p.squeeze(2)
-            po = self.base_cnn(p)
-            patch_outputs.append(po)
-
-        combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
-        return combined_output
-
-class AIIArecursive(AIIA):
-    def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
-
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
-
-        # Pass recursion_depth as a kwarg to the config
-        self.config.recursion_depth = recursion_depth
-
-        # Initialize chunked CNN with updated config
-        self.chunked_cnn = AIIAchunked(self.config, base_class, **kwargs)
-
-    def forward(self, x, depth=0):
-        if depth == self.recursion_depth:
-            return self.chunked_cnn(x)
-        else:
-            patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
-            patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
-            processed_patches = []
-
-            for p in torch.split(patches, 1, dim=2):
-                p = p.squeeze(2)
-                pp = self.forward(p, depth + 1)
-                processed_patches.append(pp)
-
-            combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
-            return combined_output
-
 if __name__ =="__main__":
     config = AIIAConfig()
     model = AIIAmoe(config, num_experts=5)

From 040ac478b994262238193ef15d730b33016de73c Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Fri, 11 Apr 2025 22:37:44 +0200
Subject: [PATCH 02/23] removed loading and saving functions since
 transformers will take over

---
 src/aiia/model/Model.py | 54 +----------------------------------------
 1 file changed, 1 insertion(+), 53 deletions(-)

diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py
index abcc34a..118acb0 100644
--- a/src/aiia/model/Model.py
+++ b/src/aiia/model/Model.py
@@ -1,5 +1,6 @@
 from .config import AIIAConfig
 from torch import nn
+from transformers import PretrainedModel
 import torch
 import os
 import copy
@@ -22,59 +23,6 @@ class AIIA(nn.Module):
         torch.save(self.state_dict(), f"{path}/model.pth")
         self.config.save(path)
 
-    @classmethod
-    def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
-        config = AIIAConfig.load(path)
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-        # Load the state dict to analyze structure
-        model_dict = torch.load(f"{path}/model.pth", map_location=device)
-
-        # Special handling for AIIAmoe - detect number of experts from state_dict
-        if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs:
-            # Find maximum expert index
-            max_expert_idx = -1
-            for key in model_dict.keys():
-                if key.startswith("experts."):
-                    parts = key.split(".")
-                    if len(parts) > 1:
-                        try:
-                            expert_idx = int(parts[1])
-                            max_expert_idx = max(max_expert_idx, expert_idx)
-                        except ValueError:
-                            pass
-
-            if max_expert_idx >= 0:
-                # experts.X keys found, use max_expert_idx + 1 as num_experts
-                kwargs["num_experts"] = max_expert_idx + 1
-
-        # Create model with detected structural parameters
-        model = cls(config, **kwargs)
-
-        # Handle precision conversion
-        dtype = None
-        if precision is not None:
-            if precision.lower() == 'fp16':
-                dtype = torch.float16
-            elif precision.lower() == 'bf16':
-                if device == 'cuda' and not torch.cuda.is_bf16_supported():
-                    warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
-                    dtype = torch.float16
-                else:
-                    dtype = torch.bfloat16
-            else:
-                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
-
-        if dtype is not None:
-            for key, param in model_dict.items():
-                if torch.is_tensor(param):
-                    model_dict[key] = param.to(dtype)
-
-        # Load state dict with strict parameter for flexibility
-        model.load_state_dict(model_dict, strict=strict)
-        return model
-
-
 class AIIABase(AIIA):
     def __init__(self, config: AIIAConfig, **kwargs):
         super().__init__(config=config, **kwargs)

From bf915a4dcae19e99bf79e91105ecb31f93e4de18 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Fri, 11 Apr 2025 22:40:00 +0200
Subject: [PATCH 03/23] added transformers to requirements

---
 requirements.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 8e2d666..60fe329 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ pytest
 pillow
 pandas
 torchvision
-pyarrow
\ No newline at end of file
+pyarrow
+transformers>=4.48.0
\ No newline at end of file

From d4db1ef116abca1142eb7433e56ff63db8c5dbce Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sat, 12 Apr 2025 21:44:03 +0200
Subject: [PATCH 04/23] updated tests for transformer switch

---
 tests/model/test_aiia.py          | 128 ++++++++++++------------------
 tests/model/test_config.py        |  44 +++++-----
 tests/pretrain/test_pretrainer.py |  40 +++++-----
 3 files changed, 93 insertions(+), 119 deletions(-)

diff --git a/tests/model/test_aiia.py b/tests/model/test_aiia.py
index bffa616..6177298 100644
--- a/tests/model/test_aiia.py
+++ b/tests/model/test_aiia.py
@@ -1,159 +1,133 @@
 import os
 import torch
-from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig, AIIASparseMoe
+from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAConfig, AIIASparseMoe
 
 def test_aiiabase_creation():
     config = AIIAConfig()
     model = AIIABase(config)
     assert isinstance(model, AIIABase)
 
-def test_aiiabase_save_load():
+def test_aiiabase_save_pretrained_from_pretrained():
     config = AIIAConfig()
     model = AIIABase(config)
-    save_path = "test_aiiabase_save_load"
+    save_pretrained_path = "test_aiiabase_save_pretrained_load"
 
-    # Save the model
-    model.save(save_path)
-    assert os.path.exists(os.path.join(save_path, "model.pth"))
-    assert os.path.exists(os.path.join(save_path, "config.json"))
+    # save_pretrained the model
+    model.save_pretrained(save_pretrained_path)
+    assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
+    assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
 
     # Load the model
-    loaded_model = AIIABase.load(save_path)
+    loaded_model = AIIABase.from_pretrained(save_pretrained_path)
 
     # Check if the loaded model is an instance of AIIABase
     assert isinstance(loaded_model, AIIABase)
 
     # Clean up
-    os.remove(os.path.join(save_path, "model.pth"))
-    os.remove(os.path.join(save_path, "config.json"))
-    os.rmdir(save_path)
+    os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
+    os.remove(os.path.join(save_pretrained_path, "config.json"))
+    os.rmdir(save_pretrained_path)
 
 def test_aiiabase_shared_creation():
     config = AIIAConfig()
     model = AIIABaseShared(config)
     assert isinstance(model, AIIABaseShared)
 
-def test_aiiabase_shared_save_load():
+def test_aiiabase_shared_save_pretrained_from_pretrained():
     config = AIIAConfig()
     model = AIIABaseShared(config)
-    save_path = "test_aiiabase_shared_save_load"
+    save_pretrained_path = "test_aiiabase_shared_save_pretrained_load"
 
-    # Save the model
-    model.save(save_path)
-    assert os.path.exists(os.path.join(save_path, "model.pth"))
-    assert os.path.exists(os.path.join(save_path, "config.json"))
+    # save_pretrained the model
+    model.save_pretrained(save_pretrained_path)
+    assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
+    assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
 
     # Load the model
-    loaded_model = AIIABaseShared.load(save_path)
+    loaded_model = AIIABaseShared.from_pretrained(save_pretrained_path)
 
     # Check if the loaded model is an instance of AIIABaseShared
     assert isinstance(loaded_model, AIIABaseShared)
 
     # Clean up
-    os.remove(os.path.join(save_path, "model.pth"))
-    os.remove(os.path.join(save_path, "config.json"))
-    os.rmdir(save_path)
+    os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
+    os.remove(os.path.join(save_pretrained_path, "config.json"))
+    os.rmdir(save_pretrained_path)
 
 def test_aiiaexpert_creation():
     config = AIIAConfig()
     model = AIIAExpert(config)
     assert isinstance(model, AIIAExpert)
 
-def test_aiiaexpert_save_load():
+def test_aiiaexpert_save_pretrained_from_pretrained():
    config = AIIAConfig()
    model = AIIAExpert(config)
-    save_path = "test_aiiaexpert_save_load"
+    save_pretrained_path = "test_aiiaexpert_save_pretrained_load"
 
-    # Save the model
-    model.save(save_path)
-    assert os.path.exists(os.path.join(save_path, "model.pth"))
-    assert os.path.exists(os.path.join(save_path, "config.json"))
+    # save_pretrained the model
+    model.save_pretrained(save_pretrained_path)
+    assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
+    assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
 
     # Load the model
-    loaded_model = AIIAExpert.load(save_path)
+    loaded_model = AIIAExpert.from_pretrained(save_pretrained_path)
 
     # Check if the loaded model is an instance of AIIAExpert
     assert isinstance(loaded_model, AIIAExpert)
 
     # Clean up
-    os.remove(os.path.join(save_path, "model.pth"))
-    os.remove(os.path.join(save_path, "config.json"))
-    os.rmdir(save_path)
+    os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
+    os.remove(os.path.join(save_pretrained_path, "config.json"))
+    os.rmdir(save_pretrained_path)
 
 def test_aiiamoe_creation():
     config = AIIAConfig()
     model = AIIAmoe(config, num_experts=5)
     assert isinstance(model, AIIAmoe)
 
-def test_aiiamoe_save_load():
+def test_aiiamoe_save_pretrained_from_pretrained():
     config = AIIAConfig()
     model = AIIAmoe(config, num_experts=5)
-    save_path = "test_aiiamoe_save_load"
+    save_pretrained_path = "test_aiiamoe_save_pretrained_load"
 
-    # Save the model
-    model.save(save_path)
-    assert os.path.exists(os.path.join(save_path, "model.pth"))
-    assert os.path.exists(os.path.join(save_path, "config.json"))
+    # save_pretrained the model
+    model.save_pretrained(save_pretrained_path)
+    assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
+    assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
 
     # Load the model
-    loaded_model = AIIAmoe.load(save_path)
+    loaded_model = AIIAmoe.from_pretrained(save_pretrained_path)
 
     # Check if the loaded model is an instance of AIIAmoe
     assert isinstance(loaded_model, AIIAmoe)
 
     # Clean up
-    os.remove(os.path.join(save_path, "model.pth"))
-    os.remove(os.path.join(save_path, "config.json"))
-    os.rmdir(save_path)
+    os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
+    os.remove(os.path.join(save_pretrained_path, "config.json"))
+    os.rmdir(save_pretrained_path)
 
 def test_aiiasparsemoe_creation():
     config = AIIAConfig()
     model = AIIASparseMoe(config, num_experts=5, top_k=2)
     assert isinstance(model, AIIASparseMoe)
 
-def test_aiiasparsemoe_save_load():
+def test_aiiasparsemoe_save_pretrained_from_pretrained():
     config = AIIAConfig()
     model = AIIASparseMoe(config, num_experts=3, top_k=1)
-    save_path = "test_aiiasparsemoe_save_load"
+    save_pretrained_path = "test_aiiasparsemoe_save_pretrained_load"
 
-    # Save the model
-    model.save(save_path)
-    assert os.path.exists(os.path.join(save_path, "model.pth"))
-    assert os.path.exists(os.path.join(save_path, "config.json"))
+    # save_pretrained the model
+    model.save_pretrained(save_pretrained_path)
+    assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
+    assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
 
     # Load the model
-    loaded_model = AIIASparseMoe.load(save_path)
+    loaded_model = AIIASparseMoe.from_pretrained(save_pretrained_path)
 
     # Check if the loaded model is an instance of AIIASparseMoe
     assert isinstance(loaded_model, AIIASparseMoe)
 
     # Clean up
-    os.remove(os.path.join(save_path, "model.pth"))
-    os.remove(os.path.join(save_path, "config.json"))
-    os.rmdir(save_path)
-
-def test_aiiachunked_creation():
-    config = AIIAConfig()
-    model = AIIAchunked(config)
-    assert isinstance(model, AIIAchunked)
-
-def test_aiiachunked_save_load():
-    config = AIIAConfig()
-    model = AIIAchunked(config)
-    save_path = "test_aiiachunked_save_load"
-
-    # Save the model
-    model.save(save_path)
-    assert os.path.exists(os.path.join(save_path, "model.pth"))
-    assert os.path.exists(os.path.join(save_path, "config.json"))
-
-    # Load the model
-    loaded_model = AIIAchunked.load(save_path)
-
-    # Check if the loaded model is an instance of AIIAchunked
-    assert isinstance(loaded_model, AIIAchunked)
-
-    # Clean up
-    os.remove(os.path.join(save_path, "model.pth"))
-    os.remove(os.path.join(save_path, "config.json"))
-    os.rmdir(save_path)
\ No newline at end of file
+    os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
+    os.remove(os.path.join(save_pretrained_path, "config.json"))
+    os.rmdir(save_pretrained_path)
diff --git a/tests/model/test_config.py b/tests/model/test_config.py
index 5542a79..93f6f4a 100644
--- a/tests/model/test_config.py
+++ b/tests/model/test_config.py
@@ -6,7 +6,7 @@ from aiia import AIIAConfig
 
 def test_aiia_config_initialization():
     config = AIIAConfig()
-    assert config.model_name == "AIIA"
+    assert config.model_type == "AIIA"
     assert config.kernel_size == 3
     assert config.activation_function == "GELU"
     assert config.hidden_size == 512
@@ -16,7 +16,7 @@ def test_aiia_config_initialization():
 
 def test_aiia_config_custom_initialization():
     config = AIIAConfig(
-        model_name="CustomModel",
+        model_type="CustomModel",
         kernel_size=5,
         activation_function="ReLU",
         hidden_size=1024,
@@ -24,7 +24,7 @@ def test_aiia_config_custom_initialization():
         num_channels=1,
         learning_rate=1e-4
     )
-    assert config.model_name == "CustomModel"
+    assert config.model_type == "CustomModel"
     assert config.kernel_size == 5
     assert config.activation_function == "ReLU"
     assert config.hidden_size == 1024
@@ -40,36 +40,36 @@ def test_aiia_config_to_dict():
     config = AIIAConfig()
     config_dict = config.to_dict()
     assert isinstance(config_dict, dict)
-    assert config_dict["model_name"] == "AIIA"
+    assert config_dict["model_type"] == "AIIA"
     assert config_dict["kernel_size"] == 3
 
-def test_aiia_config_save_and_load():
+def test_aiia_config_save_pretrained_and_from_pretrained():
     with tempfile.TemporaryDirectory() as tmpdir:
-        config = AIIAConfig(model_name="TempModel")
-        save_path = os.path.join(tmpdir, "config")
-        config.save(save_path)
+        config = AIIAConfig(model_type="TempModel")
+        save_pretrained_path = os.path.join(tmpdir, "config")
+        config.save_pretrained(save_pretrained_path)
 
-        loaded_config = AIIAConfig.load(save_path)
-        assert loaded_config.model_name == "TempModel"
+        loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
+        assert loaded_config.model_type == "TempModel"
         assert loaded_config.kernel_size == 3
         assert loaded_config.activation_function == "GELU"
 
-def test_aiia_config_save_and_load_with_custom_attributes():
+def test_aiia_config_save_pretrained_and_load_with_custom_attributes():
     with tempfile.TemporaryDirectory() as tmpdir:
-        config = AIIAConfig(model_name="TempModel", custom_attr="value")
-        save_path = os.path.join(tmpdir, "config")
-        config.save(save_path)
+        config = AIIAConfig(model_type="TempModel", custom_attr="value")
+        save_pretrained_path = os.path.join(tmpdir, "config")
+        config.save_pretrained(save_pretrained_path)
 
-        loaded_config = AIIAConfig.load(save_path)
-        assert loaded_config.model_name == "TempModel"
+        loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
+        assert loaded_config.model_type == "TempModel"
         assert loaded_config.custom_attr == "value"
 
-def test_aiia_config_save_and_load_with_nested_attributes():
+def test_aiia_config_save_pretrained_and_load_with_nested_attributes():
     with tempfile.TemporaryDirectory() as tmpdir:
-        config = AIIAConfig(model_name="TempModel", nested={"key": "value"})
-        save_path = os.path.join(tmpdir, "config")
-        config.save(save_path)
+        config = AIIAConfig(model_type="TempModel", nested={"key": "value"})
+        save_pretrained_path = os.path.join(tmpdir, "config")
+        config.save_pretrained(save_pretrained_path)
 
-        loaded_config = AIIAConfig.load(save_path)
-        assert loaded_config.model_name == "TempModel"
+        loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
+        assert loaded_config.model_type == "TempModel"
         assert loaded_config.nested == {"key": "value"}
diff --git a/tests/pretrain/test_pretrainer.py b/tests/pretrain/test_pretrainer.py
index db4c73f..8f4283c 100644
--- a/tests/pretrain/test_pretrainer.py
+++ b/tests/pretrain/test_pretrainer.py
@@ -94,7 +94,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
     # Execute training with patched methods
     with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \
          patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \
-         patch.object(Pretrainer, 'save_losses') as mock_save_losses, \
+         patch.object(Pretrainer, 'save_pretrained_losses') as mock_save_pretrained_losses, \
          patch('builtins.open', mock_open()):
 
         pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2)
@@ -104,10 +104,10 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
     assert mock_process_batch.call_count == 2
     assert mock_validate.call_count == 2
 
-    # Check for "Best model saved!" instead of model.save()
-    mock_print.assert_any_call("Best model saved!")
+    # Check for "Best model save_pretrainedd!" instead of model.save_pretrained()
+    mock_print.assert_any_call("Best model save_pretrainedd!")
 
-    mock_save_losses.assert_called_once()
+    mock_save_pretrained_losses.assert_called_once()
 
     # Verify state changes
     assert len(pretrainer.train_losses) == 2
@@ -153,13 +153,13 @@ def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
     pretrainer.projection_head = MagicMock()
     pretrainer.optimizer = MagicMock()
 
-    with patch.object(Pretrainer, 'save_losses') as mock_save_losses:
+    with patch.object(Pretrainer, 'save_pretrained_losses') as mock_save_pretrained_losses:
         pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
 
     # Verify empty loader behavior
     assert len(pretrainer.train_losses) == 1
     assert pretrainer.train_losses[0] == 0.0
-    mock_save_losses.assert_called_once()
+    mock_save_pretrained_losses.assert_called_once()
 
 @patch('pandas.concat')
 @patch('pandas.read_parquet')
@@ -180,7 +180,7 @@ def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat)
     pretrainer.optimizer = MagicMock()
 
     with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
-         patch.object(Pretrainer, 'save_losses'):
+         patch.object(Pretrainer, 'save_pretrained_losses'):
         pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
 
     # Verify None batch handling
@@ -212,7 +212,7 @@ def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_
     custom_batch_size = 16
     custom_sample_size = 5000
 
-    with patch.object(Pretrainer, 'save_losses'):
+    with patch.object(Pretrainer, 'save_pretrained_losses'):
         pretrainer.train(
             ['path/to/dataset.parquet'],
             output_path=custom_output_path,
@@ -233,7 +233,7 @@ def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_
 @patch('aiia.pretrain.pretrainer.AIIADataLoader')
 @patch('builtins.print')  # Add this to mock the print function
 def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
-    """Test that model is saved only when validation loss improves."""
+    """Test that model is save_pretrainedd only when validation loss improves."""
     real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
     mock_read_parquet.return_value.head.return_value = real_df
     mock_concat.return_value = real_df
@@ -262,11 +262,11 @@ def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_re
     # Test improving validation loss
     with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
          patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \
-         patch.object(Pretrainer, 'save_losses'):
+         patch.object(Pretrainer, 'save_pretrained_losses'):
         pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
 
-    # Check for "Best model saved!" 3 times
-    assert mock_print.call_args_list.count(call("Best model saved!")) == 3
+    # Check for "Best model save_pretrainedd!" 3 times
+    assert mock_print.call_args_list.count(call("Best model save_pretrainedd!")) == 3
 
     # Reset for next test
     mock_print.reset_mock()
@@ -278,11 +278,11 @@ def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_re
     # Test fluctuating validation loss
     with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
          patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
-         patch.object(Pretrainer, 'save_losses'):
+         patch.object(Pretrainer, 'save_pretrained_losses'):
         pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
 
-    # Should print "Best model saved!" only on first and third epochs
-    assert mock_print.call_args_list.count(call("Best model saved!")) == 2
+    # Should print "Best model save_pretrainedd!" only on first and third epochs
+    assert mock_print.call_args_list.count(call("Best model save_pretrainedd!")) == 2
 
 
 @patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
@@ -296,13 +296,13 @@ def test_validate(mock_process_batch):
     loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate)
     assert loss == 0.5
 
-# Test the save_losses method
-@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses')
-def test_save_losses(mock_save_losses):
+# Test the save_pretrained_losses method
+@patch('aiia.pretrain.pretrainer.Pretrainer.save_pretrained_losses')
+def test_save_pretrained_losses(mock_save_pretrained_losses):
     pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
     pretrainer.train_losses = [0.1, 0.2]
     pretrainer.val_losses = [0.3, 0.4]
 
     csv_file = 'losses.csv'
-    pretrainer.save_losses(csv_file)
-    mock_save_losses.assert_called_once_with(csv_file)
\ No newline at end of file
+    pretrainer.save_pretrained_losses(csv_file)
+    mock_save_pretrained_losses.assert_called_once_with(csv_file)
\ No newline at end of file

From 30695154b4202248b6141cb3019461b1da5ddd9d Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sat, 12 Apr 2025 21:44:33 +0200
Subject: [PATCH 05/23] removed deprecated models from init files

---
 src/aiia/__init__.py       | 2 +-
 src/aiia/model/__init__.py | 5 -----
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py
index 838c765..a4e0d44 100644
--- a/src/aiia/__init__.py
+++ b/src/aiia/__init__.py
@@ -1,4 +1,4 @@
-from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAmoe, AIIASparseMoe, AIIArecursive
+from .model.Model import AIIABase, AIIABaseShared, AIIAmoe, AIIASparseMoe
 from .model.config import AIIAConfig
 from .data.DataLoader import DataLoader
 from .pretrain.pretrainer import Pretrainer, ProjectionHead
diff --git a/src/aiia/model/__init__.py b/src/aiia/model/__init__.py
index a45512a..88f3211 100644
--- a/src/aiia/model/__init__.py
+++ b/src/aiia/model/__init__.py
@@ -1,20 +1,15 @@
 from .Model import (
     AIIABase,
     AIIABaseShared,
-    AIIAchunked,
     AIIAmoe,
     AIIASparseMoe,
-    AIIArecursive
 )
 
 from .config import AIIAConfig
 
 __all__ = [
     "AIIABase",
     "AIIABaseShared",
-    "AIIAchunked",
     "AIIAmoe",
     "AIIASparseMoe",
-    "AIIArecursive",
     "AIIAConfig",
-
 ]
\ No newline at end of file

From 22e5d0023e8a34a0ea4dfbc2c175742dcb8aceca Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sat, 12 Apr 2025 21:44:53 +0200
Subject: [PATCH 06/23] updated pretrainer to feature PreTrainedModel

---
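Note: from this commit on, the Pretrainer is typed against transformers' PreTrainedModel instead of the in-house AIIA base class, and the best weights are written with model.save_pretrained(output_path). A minimal sketch of the intended round trip, assuming the model classes already subclass PreTrainedModel (that port only lands in PATCH 09 below); the directory and file names here are placeholders:

    from aiia.model import AIIABase, AIIAConfig
    from aiia.pretrain import Pretrainer

    config = AIIAConfig()
    model = AIIABase(config)  # subclasses transformers.PreTrainedModel after PATCH 09
    pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)

    # train() stores the best-validation weights via model.save_pretrained(...)
    pretrainer.train(["dataset.parquet"], output_path="AIIA-base", num_epochs=3)

    # ...which makes them loadable again through the standard transformers API
    restored = AIIABase.from_pretrained("AIIA-base")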
 src/aiia/pretrain/pretrainer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py
index 42ba4b8..93df9dd 100644
--- a/src/aiia/pretrain/pretrainer.py
+++ b/src/aiia/pretrain/pretrainer.py
@@ -3,7 +3,7 @@ from torch import nn
 import csv
 import pandas as pd
 from tqdm import tqdm
-from ..model.Model import AIIA
+from transformers import PreTrainedModel
 from ..model.config import AIIAConfig
 from ..data.DataLoader import AIIADataLoader
 import os
@@ -21,12 +21,12 @@ class ProjectionHead(nn.Module):
         return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task
 
 class Pretrainer:
-    def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
+    def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
         """
         Initialize the pretrainer with a model.
 
         Args:
-            model (AIIA): The model instance to pretrain
+            model (PreTrainedModel): The model instance to pretrain
             learning_rate (float): Learning rate for optimization
             config (dict): Model configuration containing hidden_size
         """
@@ -186,11 +186,11 @@ class Pretrainer:
 
             if val_loss < best_val_loss:
                 best_val_loss = val_loss
-                self.model.save(output_path)
-                print("Best model saved!")
+                self.model.save_pretrained(output_path)
+                print("Best model save_pretrainedd!")
 
         losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
-        self.save_losses(losses_path)
+        self.save_pretrained_losses(losses_path)
 
     def _validate(self, val_loader, criterion_denoise, criterion_rotate):
         """Perform validation and return average validation loss."""
@@ -216,8 +216,8 @@ class Pretrainer:
 
         return avg_val_loss
 
-    def save_losses(self, csv_file):
-        """Save training and validation losses to a CSV file."""
+    def save_pretrained_losses(self, csv_file):
+        """save_pretrained training and validation losses to a CSV file."""
         data = list(zip(
             range(1, len(self.train_losses) + 1),
             self.train_losses,

From ba6da9ef020498a84a35de362249dc13c6e7bd84 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sun, 13 Apr 2025 22:17:46 +0200
Subject: [PATCH 07/23] restored full example

---
 example.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/example.py b/example.py
index 8dbb67e..2ce0b6a 100644
--- a/example.py
+++ b/example.py
@@ -1,10 +1,12 @@
-from aiia.model import AIIABase
-from aiia.model import AIIAConfig
-from aiia.pretrain import Pretrainer
+from src.aiia.model import AIIAmoe
+from src.aiia.model import AIIAConfig
+from src.aiia.pretrain import Pretrainer
 
 # Create your model
-config = AIIAConfig(model_name="AIIA-Base-512x20k")
-model = AIIABase(config)
+config = AIIAConfig(num_experts=5)
+model = AIIAmoe(config)
+model.save_pretrained("test")
+model = AIIAmoe.from_pretrained("test")
 
 # Initialize pretrainer with the model
 pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)

From a8ddf2b55971859a6549fe363f2513521cc0a2fa Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sun, 13 Apr 2025 22:18:26 +0200
Subject: [PATCH 08/23] updated config to not handle custom model_name;
 removed learning rate and made it transformer compatible

---
 src/aiia/model/config.py | 29 ++++++-----------------------
 1 file changed, 6 insertions(+), 23 deletions(-)

diff --git a/src/aiia/model/config.py b/src/aiia/model/config.py
index a83329a..7bf3d93 100644
--- a/src/aiia/model/config.py
+++ b/src/aiia/model/config.py
@@ -1,29 +1,25 @@
-import torch
+from transformers import PretrainedConfig
 import torch.nn as nn
-import json
-import os
 
+class AIIAConfig(PretrainedConfig):
+    model_type = "AIIA"  # Add this class attribute
 
-class AIIAConfig:
     def __init__(
         self,
-        model_name: str = "AIIA",
         kernel_size: int = 3,
         activation_function: str = "GELU",
         hidden_size: int = 512,
         num_hidden_layers: int = 12,
         num_channels: int = 3,
-        learning_rate: float = 5e-5,
         **kwargs
     ):
-        self.model_name = model_name
+        super().__init__(**kwargs)
        self.kernel_size = kernel_size
         self.activation_function = activation_function
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_channels = num_channels
-        self.learning_rate = learning_rate
-
+
         # Store additional keyword arguments as attributes
         for key, value in kwargs.items():
             setattr(self, key, value)
@@ -50,17 +46,4 @@ class AIIAConfig:
             elif isinstance(value, dict):
                 return {k: serialize(v) for k, v in value.items()}
             return value
-        return {k: serialize(v) for k, v in self.__dict__.items()}
-
-    def save(self, file_path):
-        if not os.path.exists(file_path):
-            os.makedirs(file_path, exist_ok=True)
-        with open(os.path.join(file_path, "config.json"), "w") as f:
-            # Save the recursively converted dictionary.
-            json.dump(self.to_dict(), f, indent=4)
-
-    @classmethod
-    def load(cls, file_path):
-        with open(os.path.join(file_path, "config.json"), "r") as f:
-            config_dict = json.load(f)
-        return cls(**config_dict)
\ No newline at end of file
+        return {k: serialize(v) for k, v in self.__dict__.items()}
\ No newline at end of file

From 3e78a595c9fb73ffef3927564c0851c461564597 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sun, 13 Apr 2025 22:18:59 +0200
Subject: [PATCH 09/23] removed base AIIA class and replaced it with
 transformers support

---
 src/aiia/model/Model.py | 113 +++++++++++++++++-----------------------
 1 file changed, 48 insertions(+), 65 deletions(-)

diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py
index 118acb0..1cdd8dc 100644
--- a/src/aiia/model/Model.py
+++ b/src/aiia/model/Model.py
@@ -1,68 +1,49 @@
 from .config import AIIAConfig
 from torch import nn
-from transformers import PretrainedModel
+from transformers import PreTrainedModel
 import torch
-import os
 import copy
-import warnings
 
-class AIIA(nn.Module):
-    def __init__(self, config: AIIAConfig, **kwargs):
-        super(AIIA, self).__init__()
-        # Create a deep copy of the configuration to avoid sharing
-        self.config = copy.deepcopy(config)
-
-        # Update the config with any additional keyword arguments
-        for key, value in kwargs.items():
-            setattr(self.config, key, value)
-
-    def save(self, path: str):
-        if not os.path.exists(path):
-            os.makedirs(path, exist_ok=True)
-        torch.save(self.state_dict(), f"{path}/model.pth")
-        self.config.save(path)
-
-class AIIABase(AIIA):
-    def __init__(self, config: AIIAConfig, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
+class AIIABase(PreTrainedModel):
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+
+    def __init__(self, config: AIIAConfig):
+        super().__init__(config)
 
         # Initialize layers based on configuration
         layers = []
-        in_channels = self.config.num_channels
+        in_channels = config.num_channels
 
-        for _ in range(self.config.num_hidden_layers):
+        for _ in range(config.num_hidden_layers):
             layers.extend([
-                nn.Conv2d(in_channels, self.config.hidden_size,
-                          kernel_size=self.config.kernel_size, padding=1),
-                getattr(nn, self.config.activation_function)(),
+                nn.Conv2d(in_channels, config.hidden_size,
+                          kernel_size=config.kernel_size, padding=1),
+                getattr(nn, config.activation_function)(),
                 nn.MaxPool2d(kernel_size=1, stride=1)
             ])
-            in_channels = self.config.hidden_size
+            in_channels = config.hidden_size
 
         self.cnn = nn.Sequential(*layers)
 
     def forward(self, x):
         return self.cnn(x)
 
-class AIIABaseShared(AIIA):
-    def __init__(self, config: AIIAConfig, **kwargs):
+class AIIABaseShared(PreTrainedModel):
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+
+    def __init__(self, config: AIIAConfig):
+        super().__init__(config)
         """
         Initialize the AIIABaseShared model.
 
         Args:
             config (AIIAConfig): Configuration object containing model parameters.
-            **kwargs: Additional keyword arguments to override configuration settings.
         """
-        super().__init__(config=config, **kwargs)
-
-        # Update configuration with new parameters if provided
-        self. config = copy.deepcopy(config)
-
-        for key, value in kwargs.items():
-            setattr(self.config, key, value)
-
+        super().__init__(config=config)
+
         # Initialize the network components
         self._initialize_network()
         self._initialize_activation_andPooling()
@@ -120,16 +101,17 @@ class AIIABaseShared(AIIA):
 
         return out
 
-class AIIAExpert(AIIA):
-    def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
+class AIIAExpert(PreTrainedModel):
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
 
+    def __init__(self, config: AIIAConfig, base_class=AIIABase):
+        super().__init__(config=config)
         # Initialize base CNN with configuration and chosen base class
         if issubclass(base_class, AIIABase):
-            self.base_cnn = AIIABase(self.config, **kwargs)
+            self.base_cnn = AIIABase(self.config)
         elif issubclass(base_class, AIIABaseShared):
-            self.base_cnn = AIIABaseShared(self.config, **kwargs)
+            self.base_cnn = AIIABaseShared(self.config)
         else:
             raise ValueError("Invalid base class")
 
@@ -146,31 +128,31 @@ class AIIAExpert(AIIA):
         # Process input through the base CNN
         return self.base_cnn(x)
 
-class AIIAmoe(AIIA):
-    def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, **kwargs)
+class AIIAmoe(PreTrainedModel):
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+
+    def __init__(self, config: AIIAConfig, base_class=AIIABase):
+        super().__init__(config=config)
         self.config = config
 
-        # Update the config to include the number of experts.
-        self.config.num_experts = num_experts
-
-        # Initialize multiple experts from the chosen base class.
+        # Get num_experts directly from config instead of parameter
+        num_experts = getattr(config, "num_experts", 3)  # default to 3 if not in config
+
+        # Initialize multiple experts from the chosen base class
         self.experts = nn.ModuleList([
-            AIIAExpert(self.config, base_class=base_class, **kwargs)
+            AIIAExpert(self.config, base_class=base_class)
             for _ in range(num_experts)
         ])
 
-        # To generate gating weights, we first need to determine the feature dimension.
-        # Each expert is assumed to return an output of shape (B, C, H, W); after averaging over H and W,
-        # we obtain a tensor of shape (B, C) where C is the number of channels (here assumed to be 224).
-        gate_in_features = 512  # Adjust this if your expert output changes.
+        gate_in_features = self.config.hidden_size
 
-        # Create a gating network that maps the aggregated features to num_experts weights.
+        # Create a gating network that maps the aggregated features to num_experts weights
         self.gate = nn.Sequential(
             nn.Linear(gate_in_features, num_experts),
             nn.Softmax(dim=1)
         )
-
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for the Mixture-of-Experts model.
@@ -209,9 +191,10 @@ class AIIAmoe(AIIA):
 
 
 class AIIASparseMoe(AIIAmoe):
-    def __init__(self, config: AIIAConfig, num_experts: int = 3, top_k: int = 2, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, num_experts=num_experts, base_class=base_class, **kwargs)
-        self.top_k = top_k
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
 
+    def __init__(self, config: AIIAConfig, base_class=AIIABase):
+        super().__init__(config=config, base_class=base_class)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Compute the gate_weights similar to standard moe.
@@ -221,7 +204,7 @@ class AIIASparseMoe(AIIAmoe):
         gate_weights = self.gate(gate_input)
 
         # Select the top-k experts for each input based on gating weights.
-        _, top_k_indices = gate_weights.topk(self.top_k, dim=-1)
+        _, top_k_indices = gate_weights.topk(self.config.top_k, dim=-1)
 
         # Initialize a list to store outputs from selected experts.
         merged_outputs = []
@@ -245,4 +228,4 @@ class AIIASparseMoe(AIIAmoe):
 if __name__ =="__main__":
     config = AIIAConfig()
     model = AIIAmoe(config, num_experts=5)
-    model.save("test")
\ No newline at end of file
+    model.save_pretrained("test")
\ No newline at end of file

From 9ec01e2dd0f9a10d0f6c2f87992abf308a6f2bf0 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sun, 13 Apr 2025 22:19:15 +0200
Subject: [PATCH 10/23] tests with new config

---
 tests/model/test_aiia.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/tests/model/test_aiia.py b/tests/model/test_aiia.py
index 6177298..4891472 100644
--- a/tests/model/test_aiia.py
+++ b/tests/model/test_aiia.py
@@ -12,7 +12,7 @@ def test_aiiabase_save_pretrained_from_pretrained():
     model = AIIABase(config)
     save_pretrained_path = "test_aiiabase_save_pretrained_load"
 
-    # save_pretrained the model
+    # Save the model
     model.save_pretrained(save_pretrained_path)
     assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
     assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
@@ -38,7 +38,7 @@ def test_aiiabase_shared_save_pretrained_from_pretrained():
     model = AIIABaseShared(config)
     save_pretrained_path = "test_aiiabase_shared_save_pretrained_load"
 
-    # save_pretrained the model
+    # Save the model
     model.save_pretrained(save_pretrained_path)
     assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
     assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
@@ -64,7 +64,7 @@ def test_aiiaexpert_save_pretrained_from_pretrained():
     model = AIIAExpert(config)
     save_pretrained_path = "test_aiiaexpert_save_pretrained_load"
 
-    # save_pretrained the model
+    # Save the model
     model.save_pretrained(save_pretrained_path)
     assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
     assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
@@ -81,16 +81,16 @@ def test_aiiaexpert_save_pretrained_from_pretrained():
     os.rmdir(save_pretrained_path)
 
 def test_aiiamoe_creation():
-    config = AIIAConfig()
-    model = AIIAmoe(config, num_experts=5)
+    config = AIIAConfig(num_experts=3)
+    model = AIIAmoe(config)
     assert isinstance(model, AIIAmoe)
 
 def test_aiiamoe_save_pretrained_from_pretrained():
-    config = AIIAConfig()
-    model = AIIAmoe(config, num_experts=5)
+    config = AIIAConfig(num_experts=3)
+    model = AIIAmoe(config)
     save_pretrained_path = "test_aiiamoe_save_pretrained_load"
 
-    # save_pretrained the model
+    # Save the model
     model.save_pretrained(save_pretrained_path)
     assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
     assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
@@ -107,16 +107,16 @@ def test_aiiamoe_save_pretrained_from_pretrained():
     os.rmdir(save_pretrained_path)
 
 def test_aiiasparsemoe_creation():
-    config = AIIAConfig()
-    model = AIIASparseMoe(config, num_experts=5, top_k=2)
+    config = AIIAConfig(num_experts=5, top_k=2)
+    model = AIIASparseMoe(config, base_class=AIIABaseShared)
     assert isinstance(model, AIIASparseMoe)
 
 def test_aiiasparsemoe_save_pretrained_from_pretrained():
-    config = AIIAConfig()
-    model = AIIASparseMoe(config, num_experts=3, top_k=1)
+    config = AIIAConfig(num_experts=3, top_k=1)
+    model = AIIASparseMoe(config)
     save_pretrained_path = "test_aiiasparsemoe_save_pretrained_load"
-
-    # save_pretrained the model
+
+    # Save the model
     model.save_pretrained(save_pretrained_path)
     assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
     assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
@@ -130,4 +130,4 @@ def test_aiiasparsemoe_save_pretrained_from_pretrained():
     # Clean up
     os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
     os.remove(os.path.join(save_pretrained_path, "config.json"))
-    os.rmdir(save_pretrained_path)
+    os.rmdir(save_pretrained_path)
\ No newline at end of file

From 023ca07cf7693a059a31541b247731726030492f Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sun, 13 Apr 2025 22:19:38 +0200
Subject: [PATCH 11/23] updated tests to pass new params

---
 tests/model/test_config.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/tests/model/test_config.py b/tests/model/test_config.py
index 93f6f4a..5789660 100644
--- a/tests/model/test_config.py
+++ b/tests/model/test_config.py
@@ -1,9 +1,9 @@
 import os
 import tempfile
 import pytest
-import torch.nn as nn
 from aiia import AIIAConfig
 
+
 def test_aiia_config_initialization():
     config = AIIAConfig()
     assert config.model_type == "AIIA"
@@ -12,7 +12,7 @@ def test_aiia_config_initialization():
     assert config.hidden_size == 512
     assert config.num_hidden_layers == 12
     assert config.num_channels == 3
-    assert config.learning_rate == 5e-5
+
 
 def test_aiia_config_custom_initialization():
     config = AIIAConfig(
@@ -21,8 +21,7 @@ def test_aiia_config_custom_initialization():
         activation_function="ReLU",
         hidden_size=1024,
         num_hidden_layers=8,
-        num_channels=1,
-        learning_rate=1e-4
+        num_channels=1
     )
     assert config.model_type == "CustomModel"
     assert config.kernel_size == 5
@@ -30,19 +29,20 @@ def test_aiia_config_custom_initialization():
     assert config.hidden_size == 1024
     assert config.num_hidden_layers == 8
     assert config.num_channels == 1
-    assert config.learning_rate == 1e-4
+
 
 def test_aiia_config_invalid_activation_function():
     with pytest.raises(ValueError):
         AIIAConfig(activation_function="InvalidFunction")
 
+
 def test_aiia_config_to_dict():
     config = AIIAConfig()
     config_dict = config.to_dict()
     assert isinstance(config_dict, dict)
-    assert config_dict["model_type"] == "AIIA"
     assert config_dict["kernel_size"] == 3
 
+
 def test_aiia_config_save_pretrained_and_from_pretrained():
     with tempfile.TemporaryDirectory() as tmpdir:
         config = AIIAConfig(model_type="TempModel")
@@ -54,6 +54,7 @@ def test_aiia_config_save_pretrained_and_from_pretrained():
         assert loaded_config.kernel_size == 3
         assert loaded_config.activation_function == "GELU"
 
+
 def test_aiia_config_save_pretrained_and_load_with_custom_attributes():
     with tempfile.TemporaryDirectory() as tmpdir:
         config = AIIAConfig(model_type="TempModel", custom_attr="value")
@@ -64,6 +65,7 @@ def test_aiia_config_save_pretrained_and_load_with_custom_attributes():
         assert loaded_config.model_type == "TempModel"
         assert loaded_config.custom_attr == "value"
 
+
 def test_aiia_config_save_pretrained_and_load_with_nested_attributes():
     with tempfile.TemporaryDirectory() as tmpdir:
         config = AIIAConfig(model_type="TempModel", nested={"key": "value"})
@@ -72,4 +74,4 @@ def test_aiia_config_save_pretrained_and_load_with_nested_attributes():
 
         loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
         assert loaded_config.model_type == "TempModel"
-        assert loaded_config.nested == {"key": "value"}
+        assert loaded_config.nested == {"key": "value"}
\ No newline at end of file

From 1190f3c05ff08c9c0192e25051c7b1549ca86fc6 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sun, 13 Apr 2025 22:19:59 +0200
Subject: [PATCH 12/23] updated README config example to match the new version

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 333838b..90e2a0b 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ from aiia.model import AIIAConfig
 from aiia.pretrain import Pretrainer
 
 # Create your model
-config = AIIAConfig(model_name="AIIA-Base-512x20k")
+config = AIIAConfig(model_type="AIIA-Base-512x20k")
 model = AIIABase(config)
 
 # Initialize pretrainer with the model

From 716f6e247019b46cd685ec72e15d804bbc8af208 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sun, 13 Apr 2025 22:20:42 +0200
Subject: [PATCH 13/23] increased version number

---
 pyproject.toml       | 2 +-
 setup.cfg            | 2 +-
 src/aiia/__init__.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index e4d1f2a..2220a98 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ include = '\.pyi?$'
 
 [project]
 name = "aiia"
-version = "0.2.1"
+version = "0.3.0"
 description = "AIIA Deep Learning Model Implementation"
 readme = "README.md"
 authors = [
diff --git a/setup.cfg b/setup.cfg
index eeb961d..26614c8 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = aiia
-version = 0.2.1
+version = 0.3.0
 author = Falko Habel
 author_email = falko.habel@gmx.de
 description = AIIA deep learning model implementation
diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py
index a4e0d44..699a42f 100644
--- a/src/aiia/__init__.py
+++ b/src/aiia/__init__.py
@@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
 from .pretrain.pretrainer import Pretrainer, ProjectionHead
 
 
-__version__ = "0.2.1"
+__version__ = "0.3.0"

From 5457bca9639d11dee98e90d885300c51db93a170 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Mon, 14 Apr 2025 22:00:50 +0200
Subject: [PATCH 14/23] in-between save (work in progress)

---
 src/aiia/pretrain/pretrainer.py | 79 ++++++++++++++++++++++++++++-----
 1 file changed, 69 insertions(+), 10 deletions(-)

diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py
index 42ba4b8..94d412d 100644
--- a/src/aiia/pretrain/pretrainer.py
+++ b/src/aiia/pretrain/pretrainer.py
@@ -1,6 +1,8 @@
 import torch
 from torch import nn
 import csv
+import datetime
+import time
 import pandas as pd
 from tqdm import tqdm
 from ..model.Model import AIIA
@@ -112,18 +114,21 @@ class Pretrainer:
 
         return batch_loss
 
-    def train(self, dataset_paths,output_path:str="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
-        """
-        Train the model using multiple specified datasets.
- - Args: - dataset_paths (list): List of paths to parquet datasets - num_epochs (int): Number of training epochs - batch_size (int): Batch size for training - sample_size (int): Number of samples to use from each dataset - """ + def train(self, dataset_paths, output_path="AIIA", column="image_bytes", + num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None): + """Train the model using multiple specified datasets.""" if not dataset_paths: raise ValueError("No dataset paths provided") + + # Checkpoint tracking variables + last_checkpoint_time = time.time() + checkpoint_interval = 2 * 60 * 60 # 2 hours in seconds + last_22_date = None + recent_checkpoints = [] + + # Create checkpoint directory if specified + if checkpoint_dir is not None: + os.makedirs(checkpoint_dir, exist_ok=True) # Read and merge all datasets dataframes = [] @@ -166,6 +171,59 @@ class Pretrainer: if batch_data is None: continue + # Check if we need to save a checkpoint + current_time = time.time() + current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))) # German time + today = current_dt.date() + + # Regular 2-hour checkpoint + if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval: + checkpoint_path = os.path.join( + checkpoint_dir, + f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt" + ) + torch.save({ + 'epoch': epoch + 1, + 'batch': batch_count, + 'model_state_dict': self.model.state_dict(), + 'projection_head_state_dict': self.projection_head.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'train_losses': self.train_losses, + 'val_losses': self.val_losses, + }, checkpoint_path) + + # Track and maintain only 3 recent checkpoints + recent_checkpoints.append(checkpoint_path) + if len(recent_checkpoints) > 3: + oldest = recent_checkpoints.pop(0) + if os.path.exists(oldest): + os.remove(oldest) + + last_checkpoint_time = current_time + print(f"Checkpoint saved at {checkpoint_path}") + + # Special 22:00 checkpoint + is_22_oclock = current_dt.hour == 22 and current_dt.minute == 0 and current_dt.second < 10 + + if checkpoint_dir and is_22_oclock and last_22_date != today: + checkpoint_path = os.path.join( + checkpoint_dir, + f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt" + ) + torch.save({ + 'epoch': epoch + 1, + 'batch': batch_count, + 'model_state_dict': self.model.state_dict(), + 'projection_head_state_dict': self.projection_head.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'train_losses': self.train_losses, + 'val_losses': self.val_losses, + }, checkpoint_path) + + last_22_date = today + print(f"22:00 Checkpoint saved at {checkpoint_path}") + + # Process the batch self.optimizer.zero_grad() batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate) @@ -192,6 +250,7 @@ class Pretrainer: losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv') self.save_losses(losses_path) + def _validate(self, val_loader, criterion_denoise, criterion_rotate): """Perform validation and return average validation loss.""" val_loss = 0.0 From 47b42c3ab3dd1b266c423274820a205c2bc9827c Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Mon, 14 Apr 2025 22:06:40 +0200 Subject: [PATCH 15/23] abstraction checkpoint saving --- src/aiia/pretrain/pretrainer.py | 83 +++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 30 deletions(-) diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index 94d412d..1c706f7 100644 --- a/src/aiia/pretrain/pretrainer.py +++ 
b/src/aiia/pretrain/pretrainer.py @@ -114,9 +114,55 @@ class Pretrainer: return batch_loss + def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name): + """Save a model checkpoint. + + Args: + checkpoint_dir (str): Directory to save the checkpoint + epoch (int): Current epoch number + batch_count (int): Current batch count + checkpoint_name (str): Name for the checkpoint file + + Returns: + str: Path to the saved checkpoint + """ + checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) + checkpoint_data = { + 'epoch': epoch + 1, + 'batch': batch_count, + 'model_state_dict': self.model.state_dict(), + 'projection_head_state_dict': self.projection_head.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'train_losses': self.train_losses, + 'val_losses': self.val_losses, + } + torch.save(checkpoint_data, checkpoint_path) + return checkpoint_path + def train(self, dataset_paths, output_path="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None): - """Train the model using multiple specified datasets.""" + """Train the model using multiple specified datasets. + + Args: + dataset_paths (List[str]): List of paths to parquet dataset files + output_path (str, optional): Path to save the trained model. Defaults to "AIIA". + column (str, optional): Column name containing image data. Defaults to "image_bytes". + num_epochs (int, optional): Number of training epochs. Defaults to 3. + batch_size (int, optional): Size of training batches. Defaults to 2. + sample_size (int, optional): Number of samples to use from each dataset. Defaults to 10000. + checkpoint_dir (str, optional): Directory to save checkpoints. If None, no checkpoints are saved. + + Raises: + ValueError: If no dataset paths are provided or if no valid datasets could be loaded. + + The function performs the following: + 1. Loads and merges multiple parquet datasets + 2. Trains the model using denoising and rotation tasks + 3. Validates the model performance + 4. Saves checkpoints at regular intervals (every 2 hours) and at 22:00 + 5. Maintains only the 3 most recent regular checkpoints + 6. 
Saves the best model based on validation loss + """ if not dataset_paths: raise ValueError("No dataset paths provided") @@ -129,7 +175,6 @@ class Pretrainer: # Create checkpoint directory if specified if checkpoint_dir is not None: os.makedirs(checkpoint_dir, exist_ok=True) - # Read and merge all datasets dataframes = [] for path in dataset_paths: @@ -175,22 +220,11 @@ class Pretrainer: current_time = time.time() current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))) # German time today = current_dt.date() - + # Regular 2-hour checkpoint if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval: - checkpoint_path = os.path.join( - checkpoint_dir, - f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt" - ) - torch.save({ - 'epoch': epoch + 1, - 'batch': batch_count, - 'model_state_dict': self.model.state_dict(), - 'projection_head_state_dict': self.projection_head.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), - 'train_losses': self.train_losses, - 'val_losses': self.val_losses, - }, checkpoint_path) + checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt" + checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name) # Track and maintain only 3 recent checkpoints recent_checkpoints.append(checkpoint_path) @@ -206,23 +240,12 @@ class Pretrainer: is_22_oclock = current_dt.hour == 22 and current_dt.minute == 0 and current_dt.second < 10 if checkpoint_dir and is_22_oclock and last_22_date != today: - checkpoint_path = os.path.join( - checkpoint_dir, - f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt" - ) - torch.save({ - 'epoch': epoch + 1, - 'batch': batch_count, - 'model_state_dict': self.model.state_dict(), - 'projection_head_state_dict': self.projection_head.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), - 'train_losses': self.train_losses, - 'val_losses': self.val_losses, - }, checkpoint_path) - + checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt" + checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name) last_22_date = today print(f"22:00 Checkpoint saved at {checkpoint_path}") + # Process the batch self.optimizer.zero_grad() batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate) From 9a8cefa37c547feb74706086377c7db40db06656 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Tue, 15 Apr 2025 22:17:55 +0200 Subject: [PATCH 16/23] added preloading from checkpoint with batch and epoch --- src/aiia/pretrain/pretrainer.py | 152 +++++++++++++++++++++++++------- 1 file changed, 120 insertions(+), 32 deletions(-) diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index 1c706f7..6a840ef 100644 --- a/src/aiia/pretrain/pretrainer.py +++ b/src/aiia/pretrain/pretrainer.py @@ -139,43 +139,110 @@ class Pretrainer: torch.save(checkpoint_data, checkpoint_path) return checkpoint_path - def train(self, dataset_paths, output_path="AIIA", column="image_bytes", - num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None): - """Train the model using multiple specified datasets. + def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None): + """ + Check for checkpoints and load if available. Args: - dataset_paths (List[str]): List of paths to parquet dataset files - output_path (str, optional): Path to save the trained model. Defaults to "AIIA". - column (str, optional): Column name containing image data. Defaults to "image_bytes". 
- num_epochs (int, optional): Number of training epochs. Defaults to 3. - batch_size (int, optional): Size of training batches. Defaults to 2. - sample_size (int, optional): Number of samples to use from each dataset. Defaults to 10000. - checkpoint_dir (str, optional): Directory to save checkpoints. If None, no checkpoints are saved. - - Raises: - ValueError: If no dataset paths are provided or if no valid datasets could be loaded. + checkpoint_dir (str): Directory where checkpoints are stored + specific_checkpoint (str, optional): Specific checkpoint file to load. If None, loads the most recent. - The function performs the following: - 1. Loads and merges multiple parquet datasets - 2. Trains the model using denoising and rotation tasks - 3. Validates the model performance - 4. Saves checkpoints at regular intervals (every 2 hours) and at 22:00 - 5. Maintains only the 3 most recent regular checkpoints - 6. Saves the best model based on validation loss + Returns: + tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise """ + # Create checkpoint directory if it doesn't exist + os.makedirs(checkpoint_dir, exist_ok=True) + + # If a specific checkpoint is requested + if specific_checkpoint: + checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint) + if os.path.exists(checkpoint_path): + return self._load_checkpoint_file(checkpoint_path) + else: + print(f"Specified checkpoint {specific_checkpoint} not found.") + return None + + # Find all checkpoint files + checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pt")] + + if not checkpoint_files: + print("No checkpoints found in directory.") + return None + + # Find the most recent checkpoint + checkpoint_files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True) + most_recent = checkpoint_files[0] + checkpoint_path = os.path.join(checkpoint_dir, most_recent) + + return self._load_checkpoint_file(checkpoint_path) + + def _load_checkpoint_file(self, checkpoint_path): + """ + Load a specific checkpoint file. 
+ + Args: + checkpoint_path (str): Path to the checkpoint file + + Returns: + tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise + """ + try: + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + # Load model state + self.model.load_state_dict(checkpoint['model_state_dict']) + + # Load projection head state + self.projection_head.load_state_dict(checkpoint['projection_head_state_dict']) + + # Load optimizer state + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + # Load loss history + self.train_losses = checkpoint.get('train_losses', []) + self.val_losses = checkpoint.get('val_losses', []) + + loaded_epoch = checkpoint['epoch'] + loaded_batch = checkpoint['batch'] + + print(f"Checkpoint loaded from {checkpoint_path}") + print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}") + + return loaded_epoch, loaded_batch + + except Exception as e: + print(f"Error loading checkpoint: {e}") + return None + + + def train(self, dataset_paths, output_path="AIIA", column="image_bytes", + num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None): + """Train the model using multiple specified datasets with checkpoint resumption support.""" if not dataset_paths: raise ValueError("No dataset paths provided") - # Checkpoint tracking variables + # Initialize checkpoint tracking variables last_checkpoint_time = time.time() checkpoint_interval = 2 * 60 * 60 # 2 hours in seconds last_22_date = None recent_checkpoints = [] - # Create checkpoint directory if specified + # Initialize resumption variables + start_epoch = 0 + start_batch = 0 + resume_training = False + + # Check for existing checkpoint and load if available if checkpoint_dir is not None: os.makedirs(checkpoint_dir, exist_ok=True) - # Read and merge all datasets + checkpoint_info = self.load_checkpoint(checkpoint_dir) + if checkpoint_info: + start_epoch, start_batch = checkpoint_info + resume_training = True + # Adjust epoch to be 0-indexed for the loop + start_epoch -= 1 + + # Load and merge datasets dataframes = [] for path in dataset_paths: try: @@ -198,11 +265,13 @@ class Pretrainer: collate_fn=self.safe_collate ) + # Initialize loss functions and tracking variables criterion_denoise = nn.MSELoss() criterion_rotate = nn.CrossEntropyLoss() best_val_loss = float('inf') - - for epoch in range(num_epochs): + + # Main training loop + for epoch in range(start_epoch, num_epochs): print(f"\nEpoch {epoch+1}/{num_epochs}") print("-" * 20) @@ -212,9 +281,21 @@ class Pretrainer: total_train_loss = 0.0 batch_count = 0 - for batch_data in tqdm(aiia_loader.train_loader): + # Convert data loader to enumerated list for batch tracking and resumption + train_batches = list(enumerate(aiia_loader.train_loader)) + + # Determine how many batches to skip if resuming from checkpoint + skip_batches = start_batch if (epoch == start_epoch and resume_training) else 0 + + # Process batches with proper resumption handling + for i, batch_data in tqdm(train_batches[skip_batches:], + initial=skip_batches, + total=len(train_batches)): if batch_data is None: continue + + # Use i+1 as the actual batch count (to match 1-indexed batch numbers in checkpoints) + current_batch = i + 1 # Check if we need to save a checkpoint current_time = time.time() @@ -223,8 +304,8 @@ class Pretrainer: # Regular 2-hour checkpoint if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval: - checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt" - checkpoint_path = 
self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
+                    checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{current_batch}.pt"
+                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, current_batch, checkpoint_name)
 
                     # Track and maintain only 3 recent checkpoints
                     recent_checkpoints.append(checkpoint_path)
@@ -236,16 +317,15 @@ class Pretrainer:
                     last_checkpoint_time = current_time
                     print(f"Checkpoint saved at {checkpoint_path}")
 
-                # Special 22:00 checkpoint
-                is_22_oclock = current_dt.hour == 22 and current_dt.minute == 0 and current_dt.second < 10
+                # Special 22:00 checkpoint (15-minute window so a slow batch cannot miss it)
+                is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
 
                 if checkpoint_dir and is_22_oclock and last_22_date != today:
                     checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
-                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
+                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, current_batch, checkpoint_name)
                     last_22_date = today
                     print(f"22:00 Checkpoint saved at {checkpoint_path}")
 
-                # Process the batch
                 self.optimizer.zero_grad()
                 batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
 
@@ -256,6 +336,12 @@ class Pretrainer:
                     total_train_loss += batch_loss.item()
                     batch_count += 1
 
+            # Reset batch skipping after completing the resumed epoch
+            if resume_training and epoch == start_epoch:
+                resume_training = False
+                start_batch = 0
+
+            # Calculate and store training loss
             avg_train_loss = total_train_loss / max(batch_count, 1)
             self.train_losses.append(avg_train_loss)
             print(f"Training Loss: {avg_train_loss:.4f}")
@@ -265,11 +351,13 @@ class Pretrainer:
             self.projection_head.eval()
             val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
 
+            # Save best model based on validation loss
             if val_loss < best_val_loss:
                 best_val_loss = val_loss
                 self.model.save(output_path)
                 print("Best model saved!")
 
+            # Save training history
             losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
             self.save_losses(losses_path)

From 09662d6102044f28dcb193bfb53168e63393c88c Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Tue, 15 Apr 2025 22:42:28 +0200
Subject: [PATCH 17/23] simplified code and abstracted training logic into
 helper methods

---
 src/aiia/pretrain/pretrainer.py | 224 +++++++++++++++++---------------
 1 file changed, 116 insertions(+), 108 deletions(-)

diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py
index 6a840ef..f94af2c 100644
--- a/src/aiia/pretrain/pretrainer.py
+++ b/src/aiia/pretrain/pretrainer.py
@@ -216,23 +216,54 @@ class Pretrainer:
 
 
     def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
-            num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
+              num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
         """Train the model using multiple specified datasets with checkpoint resumption support."""
         if not dataset_paths:
             raise ValueError("No dataset paths provided")
-
-        # Initialize checkpoint tracking variables
-        last_checkpoint_time = time.time()
-        checkpoint_interval = 2 * 60 * 60  # 2 hours in seconds
-        last_22_date = None
-        recent_checkpoints = []
-
-        # Initialize resumption variables
+
+        self._initialize_checkpoint_variables()
+        start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)
+
+        dataframes = self._load_and_merge_datasets(dataset_paths, sample_size)
+        aiia_loader = self._initialize_data_loader(dataframes, column, batch_size)
+
+ criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions() + + for epoch in range(start_epoch, num_epochs): + print(f"\nEpoch {epoch+1}/{num_epochs}") + print("-" * 20) + total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader, + start_batch if (epoch == start_epoch and resume_training) else 0, + criterion_denoise, + criterion_rotate) + + avg_train_loss = total_train_loss / max(batch_count, 1) + self.train_losses.append(avg_train_loss) + print(f"Training Loss: {avg_train_loss:.4f}") + + val_loss = self._validation_phase(aiia_loader.val_loader, criterion_denoise, criterion_rotate) + + if val_loss < best_val_loss: + best_val_loss = val_loss + self.model.save(output_path) + print("Best model saved!") + + losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv') + self.save_losses(losses_path) + + def _initialize_checkpoint_variables(self): + """Initialize checkpoint tracking variables.""" + self.last_checkpoint_time = time.time() + self.checkpoint_interval = 2 * 60 * 60 # 2 hours in seconds + self.last_22_date = None + self.recent_checkpoints = [] + + def _load_checkpoints(self, checkpoint_dir): + """Load checkpoints and return start epoch, batch, and resumption flag.""" start_epoch = 0 start_batch = 0 resume_training = False - - # Check for existing checkpoint and load if available + if checkpoint_dir is not None: os.makedirs(checkpoint_dir, exist_ok=True) checkpoint_info = self.load_checkpoint(checkpoint_dir) @@ -241,8 +272,11 @@ class Pretrainer: resume_training = True # Adjust epoch to be 0-indexed for the loop start_epoch -= 1 - - # Load and merge datasets + + return start_epoch, start_batch, resume_training + + def _load_and_merge_datasets(self, dataset_paths, sample_size): + """Load and merge datasets.""" dataframes = [] for path in dataset_paths: try: @@ -250,14 +284,15 @@ class Pretrainer: dataframes.append(df) except Exception as e: print(f"Error loading dataset {path}: {e}") - + if not dataframes: raise ValueError("No valid datasets could be loaded") - - merged_df = pd.concat(dataframes, ignore_index=True) - # Initialize data loader - aiia_loader = AIIADataLoader( + return pd.concat(dataframes, ignore_index=True) + + def _initialize_data_loader(self, merged_df, column, batch_size): + """Initialize the data loader.""" + return AIIADataLoader( merged_df, column=column, batch_size=batch_size, @@ -265,102 +300,75 @@ class Pretrainer: collate_fn=self.safe_collate ) - # Initialize loss functions and tracking variables + def _initialize_loss_functions(self): + """Initialize loss functions and tracking variables.""" criterion_denoise = nn.MSELoss() criterion_rotate = nn.CrossEntropyLoss() best_val_loss = float('inf') - - # Main training loop - for epoch in range(start_epoch, num_epochs): - print(f"\nEpoch {epoch+1}/{num_epochs}") - print("-" * 20) - - # Training phase - self.model.train() - self.projection_head.train() - total_train_loss = 0.0 - batch_count = 0 - - # Convert data loader to enumerated list for batch tracking and resumption - train_batches = list(enumerate(aiia_loader.train_loader)) - - # Determine how many batches to skip if resuming from checkpoint - skip_batches = start_batch if (epoch == start_epoch and resume_training) else 0 - - # Process batches with proper resumption handling - for i, batch_data in tqdm(train_batches[skip_batches:], - initial=skip_batches, - total=len(train_batches)): - if batch_data is None: - continue - - # Use i+1 as the actual batch count (to match 1-indexed batch numbers in 
checkpoints)
-                current_batch = i + 1
-
-                # Check if we need to save a checkpoint
-                current_time = time.time()
-                current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))  # German time
-                today = current_dt.date()
-
-                # Regular 2-hour checkpoint
-                if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval:
-                    checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{current_batch}.pt"
-                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, current_batch, checkpoint_name)
-
-                    # Track and maintain only 3 recent checkpoints
-                    recent_checkpoints.append(checkpoint_path)
-                    if len(recent_checkpoints) > 3:
-                        oldest = recent_checkpoints.pop(0)
-                        if os.path.exists(oldest):
-                            os.remove(oldest)
-
-                    last_checkpoint_time = current_time
-                    print(f"Checkpoint saved at {checkpoint_path}")
-
-                # Special 22:00 checkpoint (15-minute window so a slow batch cannot miss it)
-                is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
-
-                if checkpoint_dir and is_22_oclock and last_22_date != today:
-                    checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
-                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, current_batch, checkpoint_name)
-                    last_22_date = today
-                    print(f"22:00 Checkpoint saved at {checkpoint_path}")
-
-                self.optimizer.zero_grad()
-                batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
-
-                if batch_loss > 0:
-                    batch_loss.backward()
-                    self.optimizer.step()
-                    total_train_loss += batch_loss.item()
-                    batch_count += 1
-
-            # Reset batch skipping after completing the resumed epoch
-            if resume_training and epoch == start_epoch:
-                resume_training = False
-                start_batch = 0
-
-            # Calculate and store training loss
-            avg_train_loss = total_train_loss / max(batch_count, 1)
-            self.train_losses.append(avg_train_loss)
-            print(f"Training Loss: {avg_train_loss:.4f}")
+        return criterion_denoise, criterion_rotate, best_val_loss
 
-            # Validation phase
-            self.model.eval()
-            self.projection_head.eval()
-            val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
-
-            # Save best model based on validation loss
-            if val_loss < best_val_loss:
-                best_val_loss = val_loss
-                self.model.save(output_path)
-                print("Best model saved!")
+    def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate):
+        """Handle the training phase."""
+        self.model.train()
+        self.projection_head.train()
+        total_train_loss = 0.0
+        batch_count = 0
 
-            # Save training history
-            losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
-            self.save_losses(losses_path)
+        train_batches = list(enumerate(train_loader))
+        for i, batch_data in tqdm(train_batches[skip_batches:],
+                                  initial=skip_batches,
+                                  total=len(train_batches)):
+            if batch_data is None:
+                continue
-
+            current_batch = i + 1
+            self._handle_checkpoints(current_batch)
+
+            self.optimizer.zero_grad()
+            batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
+
+            if batch_loss > 0:
+                batch_loss.backward()
+                self.optimizer.step()
+                total_train_loss += batch_loss.item()
+                batch_count += 1
+
+        return total_train_loss, batch_count
+
+    def _handle_checkpoints(self, current_batch):
+        """Handle checkpoint saving logic."""
+        current_time = time.time()
+        current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))  # German time
+        today = current_dt.date()
+
+        if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
+            checkpoint_name = f"checkpoint_epoch{self.current_epoch+1}_batch{current_batch}.pt"
+            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
+
+            # Track and maintain only 3 recent checkpoints
+            self.recent_checkpoints.append(checkpoint_path)
+            if len(self.recent_checkpoints) > 3:
+                oldest = self.recent_checkpoints.pop(0)
+                if os.path.exists(oldest):
+                    os.remove(oldest)
+
+            self.last_checkpoint_time = current_time
+            print(f"Checkpoint saved at {checkpoint_path}")
+
+        # Special 22:00 checkpoint (15-minute window so a slow batch cannot miss it)
+        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
+
+        if self.checkpoint_dir and is_22_oclock and self.last_22_date != today:
+            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
+            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
+            self.last_22_date = today
+            print(f"22:00 Checkpoint saved at {checkpoint_path}")
+
+    def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate):
+        """Handle the validation phase."""
+        self.model.eval()
+        self.projection_head.eval()
+        return self._validate(val_loader, criterion_denoise, criterion_rotate)
 
     def _validate(self, val_loader, criterion_denoise, criterion_rotate):
         """Perform validation and return average validation loss."""

From 78d88068a0fd6192f783fd3bb36ede8623bfc069 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Wed, 16 Apr 2025 22:58:59 +0200
Subject: [PATCH 18/23] updated tests to work with new pretrainer

---
 tests/pretrain/test_pretrainer.py | 347 ++++++++++++++++--------------
 1 file changed, 183 insertions(+), 164 deletions(-)

diff --git a/tests/pretrain/test_pretrainer.py b/tests/pretrain/test_pretrainer.py
index db4c73f..209fd9e 100644
--- a/tests/pretrain/test_pretrainer.py
+++ b/tests/pretrain/test_pretrainer.py
@@ -3,6 +3,8 @@ import torch
 from unittest.mock import MagicMock, patch, MagicMock, mock_open, call
 from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
 import pandas as pd
+import os
+import datetime
 
 # Test the ProjectionHead class
 def test_projection_head():
@@ -53,11 +55,94 @@ def test_process_batch(mock_process_batch):
     loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
     assert loss == 0.5
 
+# Error cases
+# New tests for checkpoint handling
+@patch('torch.save')
+@patch('os.path.join')
+def test_save_checkpoint(mock_join, mock_save):
+    """Test checkpoint saving functionality."""
+    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
+    pretrainer.projection_head = MagicMock()
+    pretrainer.optimizer = MagicMock()
+
+    checkpoint_dir = "checkpoints"
+    epoch = 1
+    batch_count = 100
+    checkpoint_name = "test_checkpoint.pt"
+
+    mock_join.return_value = os.path.join(checkpoint_dir, checkpoint_name)
+
+    path = pretrainer._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
+
+    assert path == os.path.join(checkpoint_dir, checkpoint_name)
+    mock_save.assert_called_once()
+
+@patch('os.makedirs')
+@patch('os.path.exists')
+@patch('torch.load')
+def test_load_checkpoint_specific(mock_load, mock_exists, mock_makedirs):
+    """Test loading a specific checkpoint."""
+    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
+    pretrainer.projection_head = MagicMock()
+    pretrainer.optimizer = MagicMock()
+
+    checkpoint_dir = "checkpoints"
+    specific_checkpoint = "specific_checkpoint.pt"
+    mock_exists.return_value = True
+
+    mock_load.return_value = {
+        'epoch': 2,
+        'batch': 150,
+        
'model_state_dict': {}, + 'projection_head_state_dict': {}, + 'optimizer_state_dict': {}, + 'train_losses': [], + 'val_losses': [] + } + + result = pretrainer.load_checkpoint(checkpoint_dir, specific_checkpoint) + assert result == (2, 150) + +@patch('os.listdir') +@patch('os.path.getmtime') +def test_load_checkpoint_most_recent(mock_getmtime, mock_listdir): + """Test loading the most recent checkpoint.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + + checkpoint_dir = "checkpoints" + mock_listdir.return_value = ["checkpoint_1.pt", "checkpoint_2.pt"] + mock_getmtime.side_effect = [100, 200] # checkpoint_2.pt is more recent + + with patch.object(pretrainer, '_load_checkpoint_file', return_value=(2, 150)): + result = pretrainer.load_checkpoint(checkpoint_dir) + assert result == (2, 150) + +def test_initialize_checkpoint_variables(): + """Test initialization of checkpoint variables.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + pretrainer._initialize_checkpoint_variables() + + assert hasattr(pretrainer, 'last_checkpoint_time') + assert pretrainer.checkpoint_interval == 2 * 60 * 60 + assert pretrainer.last_22_date is None + assert pretrainer.recent_checkpoints == [] + +@patch('torch.nn.MSELoss') +@patch('torch.nn.CrossEntropyLoss') +def test_initialize_loss_functions(mock_ce_loss, mock_mse_loss): + """Test loss function initialization.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + criterion_denoise, criterion_rotate, best_val_loss = pretrainer._initialize_loss_functions() + + assert mock_mse_loss.called + assert mock_ce_loss.called + assert best_val_loss == float('inf') + @patch('pandas.concat') @patch('pandas.read_parquet') @patch('aiia.pretrain.pretrainer.AIIADataLoader') @patch('os.path.join', return_value='mocked/path/model.pt') -@patch('builtins.print') # Add this to mock the print function +@patch('builtins.print') def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat): """Test the train method under normal conditions with comprehensive verification.""" # Setup test data and mocks @@ -73,6 +158,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) pretrainer.projection_head = mock_projection_head pretrainer.optimizer = MagicMock() + pretrainer.checkpoint_dir = None # Initialize checkpoint_dir # Setup dataset paths and mock batch data dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet'] @@ -104,185 +190,118 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea assert mock_process_batch.call_count == 2 assert mock_validate.call_count == 2 - # Check for "Best model saved!" 
instead of model.save() mock_print.assert_any_call("Best model saved!") - mock_save_losses.assert_called_once() - # Verify state changes assert len(pretrainer.train_losses) == 2 assert pretrainer.train_losses == [0.5, 0.5] - -# Error cases -def test_train_no_dataset_paths(): - """Test ValueError when no dataset paths are provided.""" +@patch('datetime.datetime') +@patch('time.time') +def test_handle_checkpoints(mock_time, mock_datetime): + """Test checkpoint handling logic.""" pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + pretrainer.checkpoint_dir = "checkpoints" + pretrainer.current_epoch = 1 + pretrainer._initialize_checkpoint_variables() + + # Set a base time value + base_time = 1000 + # Set the last checkpoint time to base_time + pretrainer.last_checkpoint_time = base_time + + # Mock time to return base_time + interval + 1 to trigger checkpoint save + mock_time.return_value = base_time + pretrainer.checkpoint_interval + 1 + + # Mock datetime for 22:00 checkpoint + mock_dt = MagicMock() + mock_dt.hour = 22 + mock_dt.minute = 0 + mock_dt.date.return_value = datetime.date(2023, 1, 1) + mock_datetime.now.return_value = mock_dt + + with patch.object(pretrainer, '_save_checkpoint') as mock_save: + pretrainer._handle_checkpoints(100) + # Should be called twice - once for regular interval and once for 22:00 + assert mock_save.call_count == 2 - with pytest.raises(ValueError, match="No dataset paths provided"): - pretrainer.train([]) - -@patch('pandas.read_parquet') -def test_train_all_datasets_fail(mock_read_parquet): - """Test handling when all datasets fail to load.""" - mock_read_parquet.side_effect = Exception("Failed to load dataset") - +def test_training_phase(): + """Test the training phase logic.""" pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) - dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet'] - - with pytest.raises(ValueError, match="No valid datasets could be loaded"): - pretrainer.train(dataset_paths) - -# Edge cases -@patch('pandas.concat') -@patch('pandas.read_parquet') -@patch('aiia.pretrain.pretrainer.AIIADataLoader') -def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat): - """Test behavior with empty data loaders.""" - real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) - mock_read_parquet.return_value.head.return_value = real_df - mock_concat.return_value = real_df - - loader_instance = MagicMock() - loader_instance.train_loader = [] # Empty train loader - loader_instance.val_loader = [] # Empty val loader - mock_data_loader.return_value = loader_instance - - mock_model = MagicMock() - pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) - pretrainer.projection_head = MagicMock() pretrainer.optimizer = MagicMock() - - with patch.object(Pretrainer, 'save_losses') as mock_save_losses: - pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) - - # Verify empty loader behavior - assert len(pretrainer.train_losses) == 1 - assert pretrainer.train_losses[0] == 0.0 - mock_save_losses.assert_called_once() - -@patch('pandas.concat') -@patch('pandas.read_parquet') -@patch('aiia.pretrain.pretrainer.AIIADataLoader') -def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat): - """Test behavior when batch_data is None.""" - real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) - mock_read_parquet.return_value.head.return_value = real_df - mock_concat.return_value = real_df - - loader_instance = MagicMock() - 
loader_instance.train_loader = [None] # Loader returns None - loader_instance.val_loader = [] - mock_data_loader.return_value = loader_instance - - pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) - pretrainer.projection_head = MagicMock() - pretrainer.optimizer = MagicMock() - - with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \ - patch.object(Pretrainer, 'save_losses'): - pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) - - # Verify None batch handling - mock_process_batch.assert_not_called() - assert pretrainer.train_losses[0] == 0.0 - -# Parameter variations -@patch('pandas.concat') -@patch('pandas.read_parquet') -@patch('aiia.pretrain.pretrainer.AIIADataLoader') -def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat): - """Test that custom parameters are properly passed through.""" - real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) - mock_read_parquet.return_value.head.return_value = real_df - mock_concat.return_value = real_df - - loader_instance = MagicMock() - loader_instance.train_loader = [] - loader_instance.val_loader = [] - mock_data_loader.return_value = loader_instance - - pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) - pretrainer.projection_head = MagicMock() - pretrainer.optimizer = MagicMock() - - # Custom parameters - custom_output_path = "custom/output/path" - custom_column = "custom_column" - custom_batch_size = 16 - custom_sample_size = 5000 - - with patch.object(Pretrainer, 'save_losses'): - pretrainer.train( - ['path/to/dataset.parquet'], - output_path=custom_output_path, - column=custom_column, - batch_size=custom_batch_size, - sample_size=custom_sample_size - ) - - # Verify custom parameters were used - mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size) - assert mock_data_loader.call_args[1]['column'] == custom_column - assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size - - - -@patch('pandas.concat') -@patch('pandas.read_parquet') -@patch('aiia.pretrain.pretrainer.AIIADataLoader') -@patch('builtins.print') # Add this to mock the print function -def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat): - """Test that model is saved only when validation loss improves.""" - real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) - mock_read_parquet.return_value.head.return_value = real_df - mock_concat.return_value = real_df - - # Create mock batch data with proper structure + pretrainer.checkpoint_dir = None # Initialize checkpoint_dir + pretrainer._initialize_checkpoint_variables() + pretrainer.current_epoch = 0 + + # Create mock batch data with requires_grad=True mock_batch_data = { - 'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)), - 'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1])) + 'denoise': ( + torch.randn(2, 3, 32, 32, requires_grad=True), + torch.randn(2, 3, 32, 32, requires_grad=True) + ), + 'rotate': ( + torch.randn(2, 3, 32, 32, requires_grad=True), + torch.tensor([0, 1], dtype=torch.long) # Labels typically don't need gradients + ) } + mock_train_loader = [(0, mock_batch_data)] # Include batch index + + # Mock the loss functions to return tensors that require gradients + criterion_denoise = MagicMock(return_value=torch.tensor(0.5, requires_grad=True)) + criterion_rotate = MagicMock(return_value=torch.tensor(0.5, requires_grad=True)) + + with patch.object(pretrainer, 
'_process_batch', return_value=torch.tensor(0.5, requires_grad=True)), \ + patch.object(pretrainer, '_handle_checkpoints') as mock_handle_checkpoints: + + total_loss, batch_count = pretrainer._training_phase( + mock_train_loader, 0, criterion_denoise, criterion_rotate) + + assert total_loss == 0.5 + assert batch_count == 1 + mock_handle_checkpoints.assert_called_once_with(1) # Check if checkpoint handling was called - loader_instance = MagicMock() - loader_instance.train_loader = [mock_batch_data] - loader_instance.val_loader = [mock_batch_data] - mock_data_loader.return_value = loader_instance - - mock_model = MagicMock() - pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) +def test_validation_phase(): + """Test the validation phase logic.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) pretrainer.projection_head = MagicMock() - pretrainer.optimizer = MagicMock() + + mock_val_loader = [MagicMock()] + criterion_denoise = MagicMock() + criterion_rotate = MagicMock() + + with patch.object(pretrainer, '_validate', return_value=0.4): + val_loss = pretrainer._validation_phase( + mock_val_loader, criterion_denoise, criterion_rotate) + + assert val_loss == 0.4 - # Initialize the best validation loss - pretrainer.best_val_loss = float('inf') +@patch('pandas.read_parquet') +def test_load_and_merge_datasets(mock_read_parquet): + """Test dataset loading and merging.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + + mock_df = pd.DataFrame({'col': [1, 2, 3]}) + mock_read_parquet.return_value.head.return_value = mock_df + + result = pretrainer._load_and_merge_datasets(['path1.parquet', 'path2.parquet'], 1000) + assert len(result) == 6 # 2 datasets * 3 rows each - mock_batch_loss = torch.tensor(0.5, requires_grad=True) - - # Test improving validation loss - with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \ - patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \ - patch.object(Pretrainer, 'save_losses'): - pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) - - # Check for "Best model saved!" 3 times - assert mock_print.call_args_list.count(call("Best model saved!")) == 3 - - # Reset for next test - mock_print.reset_mock() - pretrainer.train_losses = [] - - # Reset best validation loss for the second test - pretrainer.best_val_loss = float('inf') - - # Test fluctuating validation loss - with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \ - patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \ - patch.object(Pretrainer, 'save_losses'): - pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) - - # Should print "Best model saved!" 
only on first and third epochs
-    assert mock_print.call_args_list.count(call("Best model saved!")) == 2
+def test_process_batch_none_tasks():
+    """Test processing batch with no tasks."""
+    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
+
+    batch_data = {
+        'denoise': None,
+        'rotate': None
+    }
+
+    loss = pretrainer._process_batch(
+        batch_data,
+        criterion_denoise=MagicMock(),
+        criterion_rotate=MagicMock()
+    )
+
+    assert loss == 0
 
 
 @patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')

From 7d24de1f7e5cacb1ed7bc0fd049053103f24ca56 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Wed, 16 Apr 2025 22:59:13 +0200
Subject: [PATCH 19/23] updated example usage to demonstrate checkpoint
 handling

---
 example.py | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/example.py b/example.py
index 8dbb67e..7be0e5a 100644
--- a/example.py
+++ b/example.py
@@ -6,19 +6,15 @@ from aiia.pretrain import Pretrainer
 config = AIIAConfig(model_name="AIIA-Base-512x20k")
 model = AIIABase(config)
 
-# Initialize pretrainer with the model
 pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)
 
-# List of dataset paths
-dataset_paths = [
-    "/path/to/dataset1.parquet",
-    "/path/to/dataset2.parquet"
-]
+# Set checkpoint directory
+checkpoint_dir = "checkpoints/my_model"
 
-# Start training with multiple datasets
+# Start training (will automatically load checkpoint if available)
 pretrainer.train(
-    dataset_paths=dataset_paths,
-    num_epochs=10,
-    batch_size=2,
-    sample_size=10000
-)
\ No newline at end of file
+    dataset_paths=["path/to/dataset1.parquet", "path/to/dataset2.parquet"],
+    output_path="trained_models/my_model",
+    checkpoint_dir=checkpoint_dir,
+    num_epochs=10
+)

From c702834cee777205010976e68167cda8a0e3012f Mon Sep 17 00:00:00 2001
From: Falko Victor Habel
Date: Thu, 17 Apr 2025 10:51:29 +0000
Subject: [PATCH 20/23] manual merge request

---
 src/aiia/pretrain/pretrainer.py | 288 ++++++++++++++++++++++++++------
 1 file changed, 233 insertions(+), 55 deletions(-)

diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py
index 93df9dd..f94af2c 100644
--- a/src/aiia/pretrain/pretrainer.py
+++ b/src/aiia/pretrain/pretrainer.py
@@ -1,9 +1,11 @@
 import torch
 from torch import nn
 import csv
+import datetime
+import time
 import pandas as pd
 from tqdm import tqdm
-from transformers import PreTrainedModel
+from ..model.Model import AIIA
 from ..model.config import AIIAConfig
 from ..data.DataLoader import AIIADataLoader
 import os
@@ -21,12 +23,12 @@ class ProjectionHead(nn.Module):
         return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task
 
 class Pretrainer:
-    def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
+    def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
         """
         Initialize the pretrainer with a model.
 
         Args:
-            model (PreTrainedModel): The model instance to pretrain
+            model (AIIA): The model instance to pretrain
             learning_rate (float): Learning rate for optimization
             config (dict): Model configuration containing hidden_size
         """
@@ -112,20 +114,169 @@ class Pretrainer:
 
         return batch_loss
 
-    def train(self, dataset_paths,output_path:str="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
-        """
-        Train the model using multiple specified datasets.
+    def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name):
+        """Save a model checkpoint.
Args: - dataset_paths (list): List of paths to parquet datasets - num_epochs (int): Number of training epochs - batch_size (int): Batch size for training - sample_size (int): Number of samples to use from each dataset + checkpoint_dir (str): Directory to save the checkpoint + epoch (int): Current epoch number + batch_count (int): Current batch count + checkpoint_name (str): Name for the checkpoint file + + Returns: + str: Path to the saved checkpoint """ + checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) + checkpoint_data = { + 'epoch': epoch + 1, + 'batch': batch_count, + 'model_state_dict': self.model.state_dict(), + 'projection_head_state_dict': self.projection_head.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'train_losses': self.train_losses, + 'val_losses': self.val_losses, + } + torch.save(checkpoint_data, checkpoint_path) + return checkpoint_path + + def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None): + """ + Check for checkpoints and load if available. + + Args: + checkpoint_dir (str): Directory where checkpoints are stored + specific_checkpoint (str, optional): Specific checkpoint file to load. If None, loads the most recent. + + Returns: + tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise + """ + # Create checkpoint directory if it doesn't exist + os.makedirs(checkpoint_dir, exist_ok=True) + + # If a specific checkpoint is requested + if specific_checkpoint: + checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint) + if os.path.exists(checkpoint_path): + return self._load_checkpoint_file(checkpoint_path) + else: + print(f"Specified checkpoint {specific_checkpoint} not found.") + return None + + # Find all checkpoint files + checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pt")] + + if not checkpoint_files: + print("No checkpoints found in directory.") + return None + + # Find the most recent checkpoint + checkpoint_files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True) + most_recent = checkpoint_files[0] + checkpoint_path = os.path.join(checkpoint_dir, most_recent) + + return self._load_checkpoint_file(checkpoint_path) + + def _load_checkpoint_file(self, checkpoint_path): + """ + Load a specific checkpoint file. 
+ + Args: + checkpoint_path (str): Path to the checkpoint file + + Returns: + tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise + """ + try: + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + # Load model state + self.model.load_state_dict(checkpoint['model_state_dict']) + + # Load projection head state + self.projection_head.load_state_dict(checkpoint['projection_head_state_dict']) + + # Load optimizer state + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + # Load loss history + self.train_losses = checkpoint.get('train_losses', []) + self.val_losses = checkpoint.get('val_losses', []) + + loaded_epoch = checkpoint['epoch'] + loaded_batch = checkpoint['batch'] + + print(f"Checkpoint loaded from {checkpoint_path}") + print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}") + + return loaded_epoch, loaded_batch + + except Exception as e: + print(f"Error loading checkpoint: {e}") + return None + + + def train(self, dataset_paths, output_path="AIIA", column="image_bytes", + num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None): + """Train the model using multiple specified datasets with checkpoint resumption support.""" if not dataset_paths: raise ValueError("No dataset paths provided") - # Read and merge all datasets + self._initialize_checkpoint_variables() + start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir) + + dataframes = self._load_and_merge_datasets(dataset_paths, sample_size) + aiia_loader = self._initialize_data_loader(dataframes, column, batch_size) + + criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions() + + for epoch in range(start_epoch, num_epochs): + print(f"\nEpoch {epoch+1}/{num_epochs}") + print("-" * 20) + total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader, + start_batch if (epoch == start_epoch and resume_training) else 0, + criterion_denoise, + criterion_rotate) + + avg_train_loss = total_train_loss / max(batch_count, 1) + self.train_losses.append(avg_train_loss) + print(f"Training Loss: {avg_train_loss:.4f}") + + val_loss = self._validation_phase(aiia_loader.val_loader, criterion_denoise, criterion_rotate) + + if val_loss < best_val_loss: + best_val_loss = val_loss + self.model.save(output_path) + print("Best model saved!") + + losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv') + self.save_losses(losses_path) + + def _initialize_checkpoint_variables(self): + """Initialize checkpoint tracking variables.""" + self.last_checkpoint_time = time.time() + self.checkpoint_interval = 2 * 60 * 60 # 2 hours in seconds + self.last_22_date = None + self.recent_checkpoints = [] + + def _load_checkpoints(self, checkpoint_dir): + """Load checkpoints and return start epoch, batch, and resumption flag.""" + start_epoch = 0 + start_batch = 0 + resume_training = False + + if checkpoint_dir is not None: + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint_info = self.load_checkpoint(checkpoint_dir) + if checkpoint_info: + start_epoch, start_batch = checkpoint_info + resume_training = True + # Adjust epoch to be 0-indexed for the loop + start_epoch -= 1 + + return start_epoch, start_batch, resume_training + + def _load_and_merge_datasets(self, dataset_paths, sample_size): + """Load and merge datasets.""" dataframes = [] for path in dataset_paths: try: @@ -133,14 +284,15 @@ class Pretrainer: dataframes.append(df) except Exception as e: print(f"Error loading dataset {path}: {e}") - + 
         if not dataframes:
             raise ValueError("No valid datasets could be loaded")
-
-        merged_df = pd.concat(dataframes, ignore_index=True)
 
-        # Initialize data loader
-        aiia_loader = AIIADataLoader(
+        return pd.concat(dataframes, ignore_index=True)
+
+    def _initialize_data_loader(self, merged_df, column, batch_size):
+        """Initialize the data loader."""
+        return AIIADataLoader(
             merged_df,
             column=column,
             batch_size=batch_size,
@@ -148,49 +300,75 @@ class Pretrainer:
             collate_fn=self.safe_collate
         )
 
+    def _initialize_loss_functions(self):
+        """Initialize loss functions and tracking variables."""
         criterion_denoise = nn.MSELoss()
         criterion_rotate = nn.CrossEntropyLoss()
         best_val_loss = float('inf')
+        return criterion_denoise, criterion_rotate, best_val_loss
 
-        for epoch in range(num_epochs):
-            print(f"\nEpoch {epoch+1}/{num_epochs}")
-            print("-" * 20)
-
-            # Training phase
-            self.model.train()
-            self.projection_head.train()
-            total_train_loss = 0.0
-            batch_count = 0
-
-            for batch_data in tqdm(aiia_loader.train_loader):
-                if batch_data is None:
-                    continue
-
-                self.optimizer.zero_grad()
-                batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
-
-                if batch_loss > 0:
-                    batch_loss.backward()
-                    self.optimizer.step()
-                    total_train_loss += batch_loss.item()
-                    batch_count += 1
-
-            avg_train_loss = total_train_loss / max(batch_count, 1)
-            self.train_losses.append(avg_train_loss)
-            print(f"Training Loss: {avg_train_loss:.4f}")
+    def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate):
+        """Handle the training phase."""
+        self.model.train()
+        self.projection_head.train()
+        total_train_loss = 0.0
+        batch_count = 0
 
-            # Validation phase
-            self.model.eval()
-            self.projection_head.eval()
-            val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
-
-            if val_loss < best_val_loss:
-                best_val_loss = val_loss
-                self.model.save_pretrained(output_path)
-                print("Best model save_pretrainedd!")
+        train_batches = list(enumerate(train_loader))
+        for i, batch_data in tqdm(train_batches[skip_batches:],
+                                  initial=skip_batches,
+                                  total=len(train_batches)):
+            if batch_data is None:
+                continue
 
-        losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
-        self.save_pretrained_losses(losses_path)
+            current_batch = i + 1
+            self._handle_checkpoints(current_batch)
+
+            self.optimizer.zero_grad()
+            batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
+
+            if batch_loss > 0:
+                batch_loss.backward()
+                self.optimizer.step()
+                total_train_loss += batch_loss.item()
+                batch_count += 1
+
+        return total_train_loss, batch_count
+
+    def _handle_checkpoints(self, current_batch):
+        """Handle checkpoint saving logic."""
+        current_time = time.time()
+        current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))  # German time
+        today = current_dt.date()
+
+        if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
+            checkpoint_name = f"checkpoint_epoch{self.current_epoch+1}_batch{current_batch}.pt"
+            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
+
+            # Track and maintain only 3 recent checkpoints
+            self.recent_checkpoints.append(checkpoint_path)
+            if len(self.recent_checkpoints) > 3:
+                oldest = self.recent_checkpoints.pop(0)
+                if os.path.exists(oldest):
+                    os.remove(oldest)
+
+            self.last_checkpoint_time = current_time
+            print(f"Checkpoint saved at {checkpoint_path}")
+
+        # Special 22:00 checkpoint (15-minute window so a slow batch cannot miss it)
+        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
+
+        if self.checkpoint_dir and is_22_oclock and self.last_22_date != today:
+            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
+            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
+            self.last_22_date = today
+            print(f"22:00 Checkpoint saved at {checkpoint_path}")
+
+    def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate):
+        """Handle the validation phase."""
+        self.model.eval()
+        self.projection_head.eval()
+        return self._validate(val_loader, criterion_denoise, criterion_rotate)
 
     def _validate(self, val_loader, criterion_denoise, criterion_rotate):
         """Perform validation and return average validation loss."""
@@ -216,8 +394,8 @@ class Pretrainer:
 
         return avg_val_loss
 
-    def save_pretrained_losses(self, csv_file):
-        """save_pretrained training and validation losses to a CSV file."""
+    def save_losses(self, csv_file):
+        """Save training and validation losses to a CSV file."""
         data = list(zip(
             range(1, len(self.train_losses) + 1),
             self.train_losses,

From dc8290fddf1de7b643e8f5200f7bf825b3add368 Mon Sep 17 00:00:00 2001
From: Falko Victor Habel
Date: Thu, 17 Apr 2025 10:52:03 +0000
Subject: [PATCH 21/23] manual PR

---
 tests/pretrain/test_pretrainer.py | 363 ++++++++++++++++--------------
 1 file changed, 191 insertions(+), 172 deletions(-)

diff --git a/tests/pretrain/test_pretrainer.py b/tests/pretrain/test_pretrainer.py
index 8f4283c..209fd9e 100644
--- a/tests/pretrain/test_pretrainer.py
+++ b/tests/pretrain/test_pretrainer.py
@@ -3,6 +3,8 @@ import torch
 from unittest.mock import MagicMock, patch, MagicMock, mock_open, call
 from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
 import pandas as pd
+import os
+import datetime
 
 # Test the ProjectionHead class
 def test_projection_head():
@@ -53,11 +55,94 @@ def test_process_batch(mock_process_batch):
     loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
     assert loss == 0.5
 
+# Error cases
+# New tests for checkpoint handling
+@patch('torch.save')
+@patch('os.path.join')
+def test_save_checkpoint(mock_join, mock_save):
+    """Test checkpoint saving functionality."""
+    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
+    pretrainer.projection_head = MagicMock()
+    pretrainer.optimizer = MagicMock()
+
+    checkpoint_dir = "checkpoints"
+    epoch = 1
+    batch_count = 100
+    checkpoint_name = "test_checkpoint.pt"
+
+    mock_join.return_value = os.path.join(checkpoint_dir, checkpoint_name)
+
+    path = pretrainer._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
+
+    assert path == os.path.join(checkpoint_dir, checkpoint_name)
+    mock_save.assert_called_once()
+
+@patch('os.makedirs')
+@patch('os.path.exists')
+@patch('torch.load')
+def test_load_checkpoint_specific(mock_load, mock_exists, mock_makedirs):
+    """Test loading a specific checkpoint."""
+    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
+    pretrainer.projection_head = MagicMock()
+    pretrainer.optimizer = MagicMock()
+
+    checkpoint_dir = "checkpoints"
+    specific_checkpoint = "specific_checkpoint.pt"
+    mock_exists.return_value = True
+
+    mock_load.return_value = {
+        'epoch': 2,
+        'batch': 150,
+        'model_state_dict': {},
+        'projection_head_state_dict': {},
+        'optimizer_state_dict': {},
+        'train_losses': [],
+        'val_losses': []
+    }
+
+    result = pretrainer.load_checkpoint(checkpoint_dir, specific_checkpoint)
+    assert result == (2, 
150) + +@patch('os.listdir') +@patch('os.path.getmtime') +def test_load_checkpoint_most_recent(mock_getmtime, mock_listdir): + """Test loading the most recent checkpoint.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + + checkpoint_dir = "checkpoints" + mock_listdir.return_value = ["checkpoint_1.pt", "checkpoint_2.pt"] + mock_getmtime.side_effect = [100, 200] # checkpoint_2.pt is more recent + + with patch.object(pretrainer, '_load_checkpoint_file', return_value=(2, 150)): + result = pretrainer.load_checkpoint(checkpoint_dir) + assert result == (2, 150) + +def test_initialize_checkpoint_variables(): + """Test initialization of checkpoint variables.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + pretrainer._initialize_checkpoint_variables() + + assert hasattr(pretrainer, 'last_checkpoint_time') + assert pretrainer.checkpoint_interval == 2 * 60 * 60 + assert pretrainer.last_22_date is None + assert pretrainer.recent_checkpoints == [] + +@patch('torch.nn.MSELoss') +@patch('torch.nn.CrossEntropyLoss') +def test_initialize_loss_functions(mock_ce_loss, mock_mse_loss): + """Test loss function initialization.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + criterion_denoise, criterion_rotate, best_val_loss = pretrainer._initialize_loss_functions() + + assert mock_mse_loss.called + assert mock_ce_loss.called + assert best_val_loss == float('inf') + @patch('pandas.concat') @patch('pandas.read_parquet') @patch('aiia.pretrain.pretrainer.AIIADataLoader') @patch('os.path.join', return_value='mocked/path/model.pt') -@patch('builtins.print') # Add this to mock the print function +@patch('builtins.print') def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat): """Test the train method under normal conditions with comprehensive verification.""" # Setup test data and mocks @@ -73,6 +158,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) pretrainer.projection_head = mock_projection_head pretrainer.optimizer = MagicMock() + pretrainer.checkpoint_dir = None # Initialize checkpoint_dir # Setup dataset paths and mock batch data dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet'] @@ -94,7 +180,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea # Execute training with patched methods with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \ patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \ - patch.object(Pretrainer, 'save_pretrained_losses') as mock_save_pretrained_losses, \ + patch.object(Pretrainer, 'save_losses') as mock_save_losses, \ patch('builtins.open', mock_open()): pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2) @@ -104,185 +190,118 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea assert mock_process_batch.call_count == 2 assert mock_validate.call_count == 2 - # Check for "Best model save_pretrainedd!" 
instead of model.save_pretrained() - mock_print.assert_any_call("Best model save_pretrainedd!") - - mock_save_pretrained_losses.assert_called_once() + mock_print.assert_any_call("Best model saved!") + mock_save_losses.assert_called_once() - # Verify state changes assert len(pretrainer.train_losses) == 2 assert pretrainer.train_losses == [0.5, 0.5] - -# Error cases -def test_train_no_dataset_paths(): - """Test ValueError when no dataset paths are provided.""" +@patch('datetime.datetime') +@patch('time.time') +def test_handle_checkpoints(mock_time, mock_datetime): + """Test checkpoint handling logic.""" pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + pretrainer.checkpoint_dir = "checkpoints" + pretrainer.current_epoch = 1 + pretrainer._initialize_checkpoint_variables() + + # Set a base time value + base_time = 1000 + # Set the last checkpoint time to base_time + pretrainer.last_checkpoint_time = base_time + + # Mock time to return base_time + interval + 1 to trigger checkpoint save + mock_time.return_value = base_time + pretrainer.checkpoint_interval + 1 + + # Mock datetime for 22:00 checkpoint + mock_dt = MagicMock() + mock_dt.hour = 22 + mock_dt.minute = 0 + mock_dt.date.return_value = datetime.date(2023, 1, 1) + mock_datetime.now.return_value = mock_dt + + with patch.object(pretrainer, '_save_checkpoint') as mock_save: + pretrainer._handle_checkpoints(100) + # Should be called twice - once for regular interval and once for 22:00 + assert mock_save.call_count == 2 - with pytest.raises(ValueError, match="No dataset paths provided"): - pretrainer.train([]) - -@patch('pandas.read_parquet') -def test_train_all_datasets_fail(mock_read_parquet): - """Test handling when all datasets fail to load.""" - mock_read_parquet.side_effect = Exception("Failed to load dataset") - +def test_training_phase(): + """Test the training phase logic.""" pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) - dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet'] - - with pytest.raises(ValueError, match="No valid datasets could be loaded"): - pretrainer.train(dataset_paths) - -# Edge cases -@patch('pandas.concat') -@patch('pandas.read_parquet') -@patch('aiia.pretrain.pretrainer.AIIADataLoader') -def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat): - """Test behavior with empty data loaders.""" - real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) - mock_read_parquet.return_value.head.return_value = real_df - mock_concat.return_value = real_df - - loader_instance = MagicMock() - loader_instance.train_loader = [] # Empty train loader - loader_instance.val_loader = [] # Empty val loader - mock_data_loader.return_value = loader_instance - - mock_model = MagicMock() - pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) - pretrainer.projection_head = MagicMock() pretrainer.optimizer = MagicMock() - - with patch.object(Pretrainer, 'save_pretrained_losses') as mock_save_pretrained_losses: - pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) - - # Verify empty loader behavior - assert len(pretrainer.train_losses) == 1 - assert pretrainer.train_losses[0] == 0.0 - mock_save_pretrained_losses.assert_called_once() - -@patch('pandas.concat') -@patch('pandas.read_parquet') -@patch('aiia.pretrain.pretrainer.AIIADataLoader') -def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat): - """Test behavior when batch_data is None.""" - real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 
224, 224).tolist()]}) - mock_read_parquet.return_value.head.return_value = real_df - mock_concat.return_value = real_df - - loader_instance = MagicMock() - loader_instance.train_loader = [None] # Loader returns None - loader_instance.val_loader = [] - mock_data_loader.return_value = loader_instance - - pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) - pretrainer.projection_head = MagicMock() - pretrainer.optimizer = MagicMock() - - with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \ - patch.object(Pretrainer, 'save_pretrained_losses'): - pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) - - # Verify None batch handling - mock_process_batch.assert_not_called() - assert pretrainer.train_losses[0] == 0.0 - -# Parameter variations -@patch('pandas.concat') -@patch('pandas.read_parquet') -@patch('aiia.pretrain.pretrainer.AIIADataLoader') -def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat): - """Test that custom parameters are properly passed through.""" - real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) - mock_read_parquet.return_value.head.return_value = real_df - mock_concat.return_value = real_df - - loader_instance = MagicMock() - loader_instance.train_loader = [] - loader_instance.val_loader = [] - mock_data_loader.return_value = loader_instance - - pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) - pretrainer.projection_head = MagicMock() - pretrainer.optimizer = MagicMock() - - # Custom parameters - custom_output_path = "custom/output/path" - custom_column = "custom_column" - custom_batch_size = 16 - custom_sample_size = 5000 - - with patch.object(Pretrainer, 'save_pretrained_losses'): - pretrainer.train( - ['path/to/dataset.parquet'], - output_path=custom_output_path, - column=custom_column, - batch_size=custom_batch_size, - sample_size=custom_sample_size - ) - - # Verify custom parameters were used - mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size) - assert mock_data_loader.call_args[1]['column'] == custom_column - assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size - - - -@patch('pandas.concat') -@patch('pandas.read_parquet') -@patch('aiia.pretrain.pretrainer.AIIADataLoader') -@patch('builtins.print') # Add this to mock the print function -def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat): - """Test that model is save_pretrainedd only when validation loss improves.""" - real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) - mock_read_parquet.return_value.head.return_value = real_df - mock_concat.return_value = real_df - - # Create mock batch data with proper structure + pretrainer.checkpoint_dir = None # Initialize checkpoint_dir + pretrainer._initialize_checkpoint_variables() + pretrainer.current_epoch = 0 + + # Create mock batch data with requires_grad=True mock_batch_data = { - 'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)), - 'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1])) + 'denoise': ( + torch.randn(2, 3, 32, 32, requires_grad=True), + torch.randn(2, 3, 32, 32, requires_grad=True) + ), + 'rotate': ( + torch.randn(2, 3, 32, 32, requires_grad=True), + torch.tensor([0, 1], dtype=torch.long) # Labels typically don't need gradients + ) } + mock_train_loader = [(0, mock_batch_data)] # Include batch index + + # Mock the loss functions to return tensors that require gradients + criterion_denoise = 
MagicMock(return_value=torch.tensor(0.5, requires_grad=True)) + criterion_rotate = MagicMock(return_value=torch.tensor(0.5, requires_grad=True)) + + with patch.object(pretrainer, '_process_batch', return_value=torch.tensor(0.5, requires_grad=True)), \ + patch.object(pretrainer, '_handle_checkpoints') as mock_handle_checkpoints: + + total_loss, batch_count = pretrainer._training_phase( + mock_train_loader, 0, criterion_denoise, criterion_rotate) + + assert total_loss == 0.5 + assert batch_count == 1 + mock_handle_checkpoints.assert_called_once_with(1) # Check if checkpoint handling was called - loader_instance = MagicMock() - loader_instance.train_loader = [mock_batch_data] - loader_instance.val_loader = [mock_batch_data] - mock_data_loader.return_value = loader_instance - - mock_model = MagicMock() - pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) +def test_validation_phase(): + """Test the validation phase logic.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) pretrainer.projection_head = MagicMock() - pretrainer.optimizer = MagicMock() + + mock_val_loader = [MagicMock()] + criterion_denoise = MagicMock() + criterion_rotate = MagicMock() + + with patch.object(pretrainer, '_validate', return_value=0.4): + val_loss = pretrainer._validation_phase( + mock_val_loader, criterion_denoise, criterion_rotate) + + assert val_loss == 0.4 - # Initialize the best validation loss - pretrainer.best_val_loss = float('inf') +@patch('pandas.read_parquet') +def test_load_and_merge_datasets(mock_read_parquet): + """Test dataset loading and merging.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + + mock_df = pd.DataFrame({'col': [1, 2, 3]}) + mock_read_parquet.return_value.head.return_value = mock_df + + result = pretrainer._load_and_merge_datasets(['path1.parquet', 'path2.parquet'], 1000) + assert len(result) == 6 # 2 datasets * 3 rows each - mock_batch_loss = torch.tensor(0.5, requires_grad=True) - - # Test improving validation loss - with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \ - patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \ - patch.object(Pretrainer, 'save_pretrained_losses'): - pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) - - # Check for "Best model save_pretrainedd!" 3 times - assert mock_print.call_args_list.count(call("Best model save_pretrainedd!")) == 3 - - # Reset for next test - mock_print.reset_mock() - pretrainer.train_losses = [] - - # Reset best validation loss for the second test - pretrainer.best_val_loss = float('inf') - - # Test fluctuating validation loss - with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \ - patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \ - patch.object(Pretrainer, 'save_pretrained_losses'): - pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) - - # Should print "Best model save_pretrainedd!" 
only on first and third epochs - assert mock_print.call_args_list.count(call("Best model save_pretrainedd!")) == 2 +def test_process_batch_none_tasks(): + """Test processing batch with no tasks.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + + batch_data = { + 'denoise': None, + 'rotate': None + } + + loss = pretrainer._process_batch( + batch_data, + criterion_denoise=MagicMock(), + criterion_rotate=MagicMock() + ) + + assert loss == 0 @patch('aiia.pretrain.pretrainer.Pretrainer._process_batch') @@ -296,13 +315,13 @@ def test_validate(mock_process_batch): loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate) assert loss == 0.5 -# Test the save_pretrained_losses method -@patch('aiia.pretrain.pretrainer.Pretrainer.save_pretrained_losses') -def test_save_pretrained_losses(mock_save_pretrained_losses): +# Test the save_losses method +@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses') +def test_save_losses(mock_save_losses): pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig()) pretrainer.train_losses = [0.1, 0.2] pretrainer.val_losses = [0.3, 0.4] csv_file = 'losses.csv' - pretrainer.save_pretrained_losses(csv_file) - mock_save_pretrained_losses.assert_called_once_with(csv_file) \ No newline at end of file + pretrainer.save_losses(csv_file) + mock_save_losses.assert_called_once_with(csv_file) \ No newline at end of file From 9b39a6926552fd2124585f4b971b509f00871ddd Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Thu, 17 Apr 2025 13:04:23 +0200 Subject: [PATCH 22/23] updated pretrainer to work with correct imports --- src/aiia/pretrain/pretrainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index f94af2c..6815a90 100644 --- a/src/aiia/pretrain/pretrainer.py +++ b/src/aiia/pretrain/pretrainer.py @@ -5,7 +5,7 @@ import datetime import time import pandas as pd from tqdm import tqdm -from ..model.Model import AIIA +from transformers import PreTrainedModel from ..model.config import AIIAConfig from ..data.DataLoader import AIIADataLoader import os @@ -23,7 +23,7 @@ class ProjectionHead(nn.Module): return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task class Pretrainer: - def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None): + def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None): """ Initialize the pretrainer with a model. 
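With this change the Pretrainer no longer depends on the removed AIIA base class and accepts any transformers.PreTrainedModel subclass. A minimal usage sketch of the resulting API, assuming the AIIA models now inherit from PreTrainedModel as the transformer switch implies ("path/to/dataset.parquet" is a placeholder path, as in the tests above):

    from aiia.model.Model import AIIABase, AIIAConfig
    from aiia.pretrain.pretrainer import Pretrainer

    config = AIIAConfig()
    # AIIABase is assumed to subclass PreTrainedModel after the switch
    model = AIIABase(config)
    pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)
    # train() loads parquet datasets by path, as exercised in the tests
    pretrainer.train(["path/to/dataset.parquet"], num_epochs=1)
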
From d20ac8cbee63819b0a8afc4e841652b1db23ff44 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Thu, 17 Apr 2025 13:28:53 +0200 Subject: [PATCH 23/23] corrected version numbering --- pyproject.toml | 2 +- setup.cfg | 2 +- src/aiia/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2220a98..ff68c3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ include = '\.pyi?$' [project] name = "aiia" -version = "0.3.0" +version = "0.3.1" description = "AIIA Deep Learning Model Implementation" readme = "README.md" authors = [ diff --git a/setup.cfg b/setup.cfg index 26614c8..598b087 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = aiia -version = 0.3.0 +version = 0.3.1 author = Falko Habel author_email = falko.habel@gmx.de description = AIIA deep learning model implementation diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py index 699a42f..c1bba6c 100644 --- a/src/aiia/__init__.py +++ b/src/aiia/__init__.py @@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader from .pretrain.pretrainer import Pretrainer, ProjectionHead -__version__ = "0.3.0" +__version__ = "0.3.1"
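
The version string lives in three places (pyproject.toml, setup.cfg, src/aiia/__init__.py) and has to be bumped in lockstep, which is what this patch corrects. A small consistency check, as a sketch only — this script is hypothetical, not part of the repository — run from the repository root with the package importable:

    import configparser
    import re

    import aiia

    # setup.cfg: [metadata] version = ...
    cfg = configparser.ConfigParser()
    cfg.read("setup.cfg")
    setup_version = cfg["metadata"]["version"]

    # pyproject.toml: version = "..." (plain regex, avoiding a TOML parser dependency)
    with open("pyproject.toml") as f:
        pyproject_version = re.search(r'^version\s*=\s*"([^"]+)"', f.read(), re.M).group(1)

    # all three sources must agree after the bump
    assert setup_version == pyproject_version == aiia.__version__ == "0.3.1"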