improved pretraining

parent 7de7eef081
commit a369c49f15

README.md · 37 changed lines
README.md
@@ -3,17 +3,28 @@
 ## Example Usage:
 
 ```Python
-if __name__ == "__main__":
-    data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
-    data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
-
-    from aiia.model import AIIABase
-    from aiia.model.config import AIIAConfig
-    from aiia.pretrain import Pretrainer
-
-    config = AIIAConfig(model_name="AIIA-Base-512x20k")
-    model = AIIABase(config)
-
-    pretrainer = Pretrainer(model, learning_rate=1e-4)
-    pretrainer.train(data_path1, data_path2, num_epochs=10)
+from aiia.model import AIIABase
+from aiia.model.config import AIIAConfig
+from aiia.pretrain import Pretrainer
+
+# Create your model
+config = AIIAConfig(model_name="AIIA-Base-512x20k")
+model = AIIABase(config)
+
+# Initialize pretrainer with the model
+pretrainer = Pretrainer(model, learning_rate=1e-4)
+
+# List of dataset paths
+dataset_paths = [
+    "/path/to/dataset1.parquet",
+    "/path/to/dataset2.parquet"
+]
+
+# Start training with multiple datasets
+pretrainer.train(
+    dataset_paths=dataset_paths,
+    num_epochs=10,
+    batch_size=2,
+    sample_size=10000
+)
 ```
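The README change tracks the new `Pretrainer.train` signature introduced in the last hunk of this commit: the two positional dataset paths become a single `dataset_paths` list. A minimal before/after sketch, assuming a `pretrainer` constructed as in the README (paths are placeholders):

```Python
# Old call, removed in this commit:
#   pretrainer.train(data_path1, data_path2, num_epochs=10)

# New call: any number of parquet datasets in a single list
pretrainer.train(
    dataset_paths=["/path/to/dataset1.parquet", "/path/to/dataset2.parquet"],
    num_epochs=10,
    batch_size=2,       # default from the train() signature
    sample_size=10000   # rows taken from the head of each dataset
)
```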
(new file)
@@ -0,0 +1,27 @@
+data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
+data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
+
+from aiia.model import AIIABase
+from aiia.model.config import AIIAConfig
+from aiia.pretrain import Pretrainer
+
+# Create your model
+config = AIIAConfig(model_name="AIIA-Base-512x10k-small", num_hidden_layers=6, hidden_size=256)
+model = AIIABase(config)
+
+# Initialize pretrainer with the model
+pretrainer = Pretrainer(model, learning_rate=config.learning_rate)
+
+# List of dataset paths
+dataset_paths = [
+    data_path1,
+    data_path2
+]
+
+# Start training with multiple datasets
+pretrainer.train(
+    dataset_paths=dataset_paths,
+    num_epochs=10,
+    batch_size=2,
+    sample_size=10000
+)
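The hunk above adds a standalone example script (its filename is not shown in this view). It leaves the new `column` keyword at its default of `"image_bytes"`; for parquet files that store images under another column, the keyword can be overridden. A sketch, with `"img_data"` as a hypothetical column name:

```Python
# `column` names the parquet column holding image bytes; "image_bytes" is the default
pretrainer.train(
    dataset_paths=dataset_paths,
    column="img_data",  # hypothetical column name, for illustration only
    num_epochs=10,
    batch_size=2,
    sample_size=10000
)
```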
@@ -108,26 +108,37 @@ class Pretrainer:
 
         return batch_loss
 
-    def train(self, data_path1, data_path2, num_epochs=3, batch_size=2, sample_size=10000):
+    def train(self, dataset_paths, column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
         """
-        Train the model using the specified datasets.
+        Train the model using multiple specified datasets.
 
         Args:
-            data_path1 (str): Path to first dataset
-            data_path2 (str): Path to second dataset
+            dataset_paths (list): List of paths to parquet datasets
             num_epochs (int): Number of training epochs
             batch_size (int): Batch size for training
             sample_size (int): Number of samples to use from each dataset
         """
-        # Read and merge datasets
-        df1 = pd.read_parquet(data_path1).head(sample_size)
-        df2 = pd.read_parquet(data_path2).head(sample_size)
-        merged_df = pd.concat([df1, df2], ignore_index=True)
+        if not dataset_paths:
+            raise ValueError("No dataset paths provided")
+
+        # Read and merge all datasets
+        dataframes = []
+        for path in dataset_paths:
+            try:
+                df = pd.read_parquet(path).head(sample_size)
+                dataframes.append(df)
+            except Exception as e:
+                print(f"Error loading dataset {path}: {e}")
+
+        if not dataframes:
+            raise ValueError("No valid datasets could be loaded")
+
+        merged_df = pd.concat(dataframes, ignore_index=True)
 
         # Initialize data loader
         aiia_loader = AIIADataLoader(
             merged_df,
-            column="image_bytes",
+            column=column,
             batch_size=batch_size,
             pretraining=True,
             collate_fn=self.safe_collate
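The rewritten loading logic above skips datasets that fail to load and aborts only when none load at all. A self-contained sketch of the same behavior, runnable with pandas alone (paths are placeholders):

```Python
import pandas as pd

def load_and_merge(dataset_paths, sample_size=10000):
    """Mirrors the loading loop in Pretrainer.train: skip bad files, fail only if all fail."""
    if not dataset_paths:
        raise ValueError("No dataset paths provided")
    dataframes = []
    for path in dataset_paths:
        try:
            # .head(sample_size) caps the rows taken from each dataset, as in the commit
            dataframes.append(pd.read_parquet(path).head(sample_size))
        except Exception as e:
            print(f"Error loading dataset {path}: {e}")
    if not dataframes:
        raise ValueError("No valid datasets could be loaded")
    return pd.concat(dataframes, ignore_index=True)

# Placeholder paths: an unreadable second file is reported but not fatal
merged = load_and_merge(["data1.parquet", "data2.parquet"])
```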