From a369c49f15eb1213cdf2dd952d12144e0fc62d82 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Tue, 28 Jan 2025 11:27:42 +0100
Subject: [PATCH] improved pretraining

---
 README.md                       | 37 +++++++++++++++++++++------------
 example.py                      | 27 ++++++++++++++++++++++++
 src/aiia/pretrain/pretrainer.py | 30 +++++++++++++++++++---------
 3 files changed, 72 insertions(+), 22 deletions(-)
 create mode 100644 example.py

diff --git a/README.md b/README.md
index 830f111..6f149b1 100644
--- a/README.md
+++ b/README.md
@@ -3,17 +3,28 @@
 ## Example Usage:
 
 ```Python
-if __name__ == "__main__":
-    data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
-    data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
-
-    from aiia.model import AIIABase
-    from aiia.model.config import AIIAConfig
-    from aiia.pretrain import Pretrainer
-
-    config = AIIAConfig(model_name="AIIA-Base-512x20k")
-    model = AIIABase(config)
-
-    pretrainer = Pretrainer(model, learning_rate=1e-4)
-    pretrainer.train(data_path1, data_path2, num_epochs=10)
+from aiia.model import AIIABase
+from aiia.model.config import AIIAConfig
+from aiia.pretrain import Pretrainer
+
+# Create your model
+config = AIIAConfig(model_name="AIIA-Base-512x20k")
+model = AIIABase(config)
+
+# Initialize pretrainer with the model
+pretrainer = Pretrainer(model, learning_rate=1e-4)
+
+# List of dataset paths
+dataset_paths = [
+    "/path/to/dataset1.parquet",
+    "/path/to/dataset2.parquet"
+]
+
+# Start training with multiple datasets
+pretrainer.train(
+    dataset_paths=dataset_paths,
+    num_epochs=10,
+    batch_size=2,
+    sample_size=10000
+)
 ```
\ No newline at end of file

diff --git a/example.py b/example.py
new file mode 100644
index 0000000..6e1620b
--- /dev/null
+++ b/example.py
@@ -0,0 +1,27 @@
+data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
+data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
+
+from aiia.model import AIIABase
+from aiia.model.config import AIIAConfig
+from aiia.pretrain import Pretrainer
+
+# Create your model
+config = AIIAConfig(model_name="AIIA-Base-512x10k-small", num_hidden_layers=6, hidden_size=256)
+model = AIIABase(config)
+
+# Initialize pretrainer with the model
+pretrainer = Pretrainer(model, learning_rate=config.learning_rate)
+
+# List of dataset paths
+dataset_paths = [
+    data_path1,
+    data_path2
+]
+
+# Start training with multiple datasets
+pretrainer.train(
+    dataset_paths=dataset_paths,
+    num_epochs=10,
+    batch_size=2,
+    sample_size=10000
+)
\ No newline at end of file

diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py
index b540db0..913b77a 100644
--- a/src/aiia/pretrain/pretrainer.py
+++ b/src/aiia/pretrain/pretrainer.py
@@ -108,26 +108,38 @@ class Pretrainer:
 
         return batch_loss
 
-    def train(self, data_path1, data_path2, num_epochs=3, batch_size=2, sample_size=10000):
+    def train(self, dataset_paths, column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
         """
-        Train the model using the specified datasets.
+        Train the model using multiple specified datasets.
 
         Args:
-            data_path1 (str): Path to first dataset
-            data_path2 (str): Path to second dataset
+            dataset_paths (list): List of paths to parquet datasets
+            column (str): Name of the dataframe column containing the image data
             num_epochs (int): Number of training epochs
             batch_size (int): Batch size for training
             sample_size (int): Number of samples to use from each dataset
         """
-        # Read and merge datasets
-        df1 = pd.read_parquet(data_path1).head(sample_size)
-        df2 = pd.read_parquet(data_path2).head(sample_size)
-        merged_df = pd.concat([df1, df2], ignore_index=True)
+        if not dataset_paths:
+            raise ValueError("No dataset paths provided")
+
+        # Read and merge all datasets
+        dataframes = []
+        for path in dataset_paths:
+            try:
+                df = pd.read_parquet(path).head(sample_size)
+                dataframes.append(df)
+            except Exception as e:
+                print(f"Error loading dataset {path}: {e}")
+
+        if not dataframes:
+            raise ValueError("No valid datasets could be loaded")
+
+        merged_df = pd.concat(dataframes, ignore_index=True)
 
         # Initialize data loader
         aiia_loader = AIIADataLoader(
             merged_df,
-            column="image_bytes",
+            column=column,
             batch_size=batch_size,
             pretraining=True,
             collate_fn=self.safe_collate
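
The patched `train()` drops the old two-positional-path signature, so existing callers need a small migration. A minimal before/after sketch of the call pattern introduced by this patch (the dataset paths are placeholders):

```Python
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer

config = AIIAConfig(model_name="AIIA-Base-512x20k")
model = AIIABase(config)
pretrainer = Pretrainer(model, learning_rate=1e-4)

# Before this patch: exactly two datasets, column name fixed to "image_bytes"
# pretrainer.train(data_path1, data_path2, num_epochs=10)

# After this patch: any number of datasets; a dataset that fails to load is
# skipped with a printed warning, and ValueError is raised only if none load
pretrainer.train(
    dataset_paths=[
        "/path/to/dataset1.parquet",  # placeholder paths
        "/path/to/dataset2.parquet",
    ],
    column="image_bytes",  # default; override if your parquet uses another column
    num_epochs=10,
    batch_size=2,
    sample_size=10000,
)
```

Because the two `ValueError`s propagate out of `train()` while per-file read errors are only printed, callers can distinguish a misconfigured run (no usable datasets) from a partially degraded one.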