From 7de7eef0810bd134819335d3938da04fcbcba53d Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Tue, 28 Jan 2025 11:16:09 +0100
Subject: [PATCH] updated pretraining to create an extra class for Pretraining

---
 README.md                       |  17 +++
 src/aiia/__init__.py            |   2 +
 src/aiia/pretrain/__init__.py   |   3 +
 src/aiia/pretrain/pretrainer.py | 219 +++++++++++++++++++++++++++++++
 src/pretrain.py                 | 226 --------------------------------
 5 files changed, 241 insertions(+), 226 deletions(-)
 create mode 100644 src/aiia/pretrain/__init__.py
 create mode 100644 src/aiia/pretrain/pretrainer.py
 delete mode 100644 src/pretrain.py

diff --git a/README.md b/README.md
index 0d888a0..830f111 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,19 @@
 # AIIA
+
+## Example Usage:
+```Python
+if __name__ == "__main__":
+    data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
+    data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
+
+    from aiia.model import AIIABase
+    from aiia.model.config import AIIAConfig
+    from aiia.pretrain import Pretrainer
+
+    config = AIIAConfig(model_name="AIIA-Base-512x20k")
+    model = AIIABase(config)
+
+    pretrainer = Pretrainer(model, learning_rate=1e-4)
+    pretrainer.train(data_path1, data_path2, num_epochs=10)
+```
\ No newline at end of file
diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py
index 6dbc27a..6a27146 100644
--- a/src/aiia/__init__.py
+++ b/src/aiia/__init__.py
@@ -1,5 +1,7 @@
 from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAExpert, AIIAmoe, AIIA, AIIArecursive
 from .model.config import AIIAConfig
 from .data.DataLoader import DataLoader
+from .pretrain.pretrainer import Pretrainer, ProjectionHead
+
 __version__ = "0.1.0"
diff --git a/src/aiia/pretrain/__init__.py b/src/aiia/pretrain/__init__.py
new file mode 100644
index 0000000..c45cbc4
--- /dev/null
+++ b/src/aiia/pretrain/__init__.py
@@ -0,0 +1,3 @@
+from .pretrainer import Pretrainer, ProjectionHead
+
+__all__ = ["Pretrainer", "ProjectionHead"]
\ No newline at end of file
diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py
new file mode 100644
index 0000000..b540db0
--- /dev/null
+++ b/src/aiia/pretrain/pretrainer.py
@@ -0,0 +1,219 @@
+import torch
+from torch import nn
+import csv
+import pandas as pd
+from tqdm import tqdm
+from ..model.Model import AIIA
+from ..data.DataLoader import AIIADataLoader
+
+class ProjectionHead(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
+        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1) # 4 classes for 0, 90, 180, 270 degrees
+
+    def forward(self, x, task='denoise'):
+        if task == 'denoise':
+            return self.conv_denoise(x)
+        else:
+            return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
+
+class Pretrainer:
+    def __init__(self, model: AIIA, learning_rate=1e-4):
+        """
+        Initialize the pretrainer with a model.
+
+        Args:
+            model (AIIA): The model instance to pretrain
+            learning_rate (float): Learning rate for optimization
+        """
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = model.to(self.device)
+        self.projection_head = ProjectionHead().to(self.device)
+        self.optimizer = torch.optim.AdamW(
+            list(self.model.parameters()) + list(self.projection_head.parameters()),
+            lr=learning_rate
+        )
+        self.train_losses = []
+        self.val_losses = []
+
+    @staticmethod
+    def safe_collate(batch):
+        """Safely collate batch data handling both denoise and rotate tasks."""
+        denoise_batch = []
+        rotate_batch = []
+
+        for sample in batch:
+            try:
+                noisy_img, target, task = sample
+                if task == 'denoise':
+                    denoise_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
+                else:  # rotate task
+                    rotate_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
+            except Exception as e:
+                print(f"Skipping sample due to error: {e}")
+                continue
+
+        if not denoise_batch and not rotate_batch:
+            return None
+
+        batch_data = {
+            'denoise': None,
+            'rotate': None
+        }
+
+        if denoise_batch:
+            images = torch.stack([x['image'] for x in denoise_batch])
+            targets = torch.stack([x['target'] for x in denoise_batch])
+            batch_data['denoise'] = (images, targets)
+
+        if rotate_batch:
+            images = torch.stack([x['image'] for x in rotate_batch])
+            targets = torch.stack([x['target'] for x in rotate_batch])
+            batch_data['rotate'] = (images, targets)
+
+        return batch_data
+
+    def _process_batch(self, batch_data, criterion_denoise, criterion_rotate, training=True):
+        """Process a single batch of data."""
+        batch_loss = 0
+
+        if batch_data['denoise'] is not None:
+            noisy_imgs, targets = batch_data['denoise']
+            noisy_imgs = noisy_imgs.to(self.device)
+            targets = targets.to(self.device)
+
+            features = self.model(noisy_imgs)
+            outputs = self.projection_head(features, task='denoise')
+            loss = criterion_denoise(outputs, targets)
+            batch_loss += loss
+
+        if batch_data['rotate'] is not None:
+            imgs, targets = batch_data['rotate']
+            imgs = imgs.to(self.device)
+            targets = targets.long().to(self.device)
+
+            features = self.model(imgs)
+            outputs = self.projection_head(features, task='rotate')
+            loss = criterion_rotate(outputs, targets)
+            batch_loss += loss
+
+        return batch_loss
+
+    def train(self, data_path1, data_path2, num_epochs=3, batch_size=2, sample_size=10000):
+        """
+        Train the model using the specified datasets.
+
+        Args:
+            data_path1 (str): Path to first dataset
+            data_path2 (str): Path to second dataset
+            num_epochs (int): Number of training epochs
+            batch_size (int): Batch size for training
+            sample_size (int): Number of samples to use from each dataset
+        """
+        # Read and merge datasets
+        df1 = pd.read_parquet(data_path1).head(sample_size)
+        df2 = pd.read_parquet(data_path2).head(sample_size)
+        merged_df = pd.concat([df1, df2], ignore_index=True)
+
+        # Initialize data loader
+        aiia_loader = AIIADataLoader(
+            merged_df,
+            column="image_bytes",
+            batch_size=batch_size,
+            pretraining=True,
+            collate_fn=self.safe_collate
+        )
+
+        criterion_denoise = nn.MSELoss()
+        criterion_rotate = nn.CrossEntropyLoss()
+        best_val_loss = float('inf')
+
+        for epoch in range(num_epochs):
+            print(f"\nEpoch {epoch+1}/{num_epochs}")
+            print("-" * 20)
+
+            # Training phase
+            self.model.train()
+            self.projection_head.train()
+            total_train_loss = 0.0
+            batch_count = 0
+
+            for batch_data in tqdm(aiia_loader.train_loader):
+                if batch_data is None:
+                    continue
+
+                self.optimizer.zero_grad()
+                batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
+
+                if batch_loss > 0:
+                    batch_loss.backward()
+                    self.optimizer.step()
+                    total_train_loss += batch_loss.item()
+                    batch_count += 1
+
+            avg_train_loss = total_train_loss / max(batch_count, 1)
+            self.train_losses.append(avg_train_loss)
+            print(f"Training Loss: {avg_train_loss:.4f}")
+
+            # Validation phase
+            self.model.eval()
+            self.projection_head.eval()
+            val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
+
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                self.save_model("AIIA-base-512")
+                print("Best model saved!")
+
+        self.save_losses('losses.csv')
+
+    def _validate(self, val_loader, criterion_denoise, criterion_rotate):
+        """Perform validation and return average validation loss."""
+        val_loss = 0.0
+        val_batch_count = 0
+
+        with torch.no_grad():
+            for batch_data in val_loader:
+                if batch_data is None:
+                    continue
+
+                batch_loss = self._process_batch(
+                    batch_data, criterion_denoise, criterion_rotate, training=False
+                )
+
+                if batch_loss > 0:
+                    val_loss += batch_loss.item()
+                    val_batch_count += 1
+
+        avg_val_loss = val_loss / max(val_batch_count, 1)
+        self.val_losses.append(avg_val_loss)
+        print(f"Validation Loss: {avg_val_loss:.4f}")
+        return avg_val_loss
+
+    def save_model(self, path):
+        """Save the model and projection head."""
+        self.model.save(path)
+        torch.save(self.projection_head.state_dict(), f"{path}_projection_head.pth")
+
+    def save_losses(self, csv_file):
+        """Save training and validation losses to a CSV file."""
+        data = list(zip(
+            range(1, len(self.train_losses) + 1),
+            self.train_losses,
+            self.val_losses
+        ))
+
+        with open(csv_file, mode='w', newline='') as file:
+            writer = csv.writer(file)
+            writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
+            writer.writerows(data)
+        print(f"Loss data has been written to {csv_file}")
\ No newline at end of file
diff --git a/src/pretrain.py b/src/pretrain.py
deleted file mode 100644
index 02e4e7f..0000000
--- a/src/pretrain.py
+++ /dev/null
@@ -1,226 +0,0 @@
-import torch
-from torch import nn
-import csv
-import pandas as pd
-from aiia.model.config import AIIAConfig
-from aiia.model import AIIABase
-from aiia.data.DataLoader import AIIADataLoader
-from tqdm import tqdm
-
-class ProjectionHead(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
-        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1) # 4 classes for 0, 90, 180, 270 degrees
-
-    def forward(self, x, task='denoise'):
-        if task == 'denoise':
-            return self.conv_denoise(x)
-        else:
-            return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
-
-def pretrain_model(data_path1, data_path2, num_epochs=3):
-    # Read and merge datasets
-    df1 = pd.read_parquet(data_path1).head(10000)
-    df2 = pd.read_parquet(data_path2).head(10000)
-    merged_df = pd.concat([df1, df2], ignore_index=True)
-
-    # Model configuration
-    config = AIIAConfig(
-        model_name="AIIA-Base-512x20k",
-    )
-
-    # Initialize model and projection head
-    model = AIIABase(config)
-    projection_head = ProjectionHead()
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model.to(device)
-    projection_head.to(device)
-
-    def safe_collate(batch):
-        denoise_batch = []
-        rotate_batch = []
-
-        for sample in batch:
-            try:
-                noisy_img, target, task = sample
-                if task == 'denoise':
-                    denoise_batch.append({
-                        'image': noisy_img,
-                        'target': target,
-                        'task': task
-                    })
-                else:  # rotate task
-                    rotate_batch.append({
-                        'image': noisy_img,
-                        'target': target,
-                        'task': task
-                    })
-            except Exception as e:
-                print(f"Skipping sample due to error: {e}")
-                continue
-
-        if not denoise_batch and not rotate_batch:
-            return None
-
-        batch_data = {
-            'denoise': None,
-            'rotate': None
-        }
-
-        if denoise_batch:
-            images = torch.stack([x['image'] for x in denoise_batch])
-            targets = torch.stack([x['target'] for x in denoise_batch])
-            batch_data['denoise'] = (images, targets)
-
-        if rotate_batch:
-            images = torch.stack([x['image'] for x in rotate_batch])
-            targets = torch.stack([x['target'] for x in rotate_batch])
-            batch_data['rotate'] = (images, targets)
-
-        return batch_data
-
-    aiia_loader = AIIADataLoader(
-        merged_df,
-        column="image_bytes",
-        batch_size=2,
-        pretraining=True,
-        collate_fn=safe_collate
-    )
-
-    train_loader = aiia_loader.train_loader
-    val_loader = aiia_loader.val_loader
-
-    criterion_denoise = nn.MSELoss()
-    criterion_rotate = nn.CrossEntropyLoss()
-
-    # Update optimizer to include projection head parameters
-    optimizer = torch.optim.AdamW(
-        list(model.parameters()) + list(projection_head.parameters()),
-        lr=config.learning_rate
-    )
-
-    best_val_loss = float('inf')
-    train_losses = []
-    val_losses = []
-    for epoch in range(num_epochs):
-        print(f"\nEpoch {epoch+1}/{num_epochs}")
-        print("-" * 20)
-
-        # Training phase
-        model.train()
-        projection_head.train()
-        total_train_loss = 0.0
-        batch_count = 0
-
-        for batch_data in tqdm(train_loader):
-            if batch_data is None:
-                continue
-
-            optimizer.zero_grad()
-            batch_loss = 0
-
-            # Handle denoise task
-            if batch_data['denoise'] is not None:
-                noisy_imgs, targets = batch_data['denoise']
-                noisy_imgs = noisy_imgs.to(device)
-                targets = targets.to(device)
-
-                # Get features from base model
-                features = model(noisy_imgs)
-                # Project features back to image space
-                outputs = projection_head(features, task='denoise')
-                loss = criterion_denoise(outputs, targets)
-                batch_loss += loss
-
-            # Handle rotate task
-            if batch_data['rotate'] is not None:
-                imgs, targets = batch_data['rotate']
-                imgs = imgs.to(device)
-                targets = targets.long().to(device)
-
-                # Get features from base model
-                features = model(imgs)
-                # Project features to rotation predictions
-                outputs = projection_head(features, task='rotate')
-
-                loss = criterion_rotate(outputs, targets)
-                batch_loss += loss
-
-            if batch_loss > 0:
-                batch_loss.backward()
-                optimizer.step()
-                total_train_loss += batch_loss.item()
-                batch_count += 1
-
-        avg_train_loss = total_train_loss / max(batch_count, 1)
-        train_losses.append(avg_train_loss)
-        print(f"Training Loss: {avg_train_loss:.4f}")
-
-        # Validation phase
-        model.eval()
-        projection_head.eval()
-        val_loss = 0.0
-        val_batch_count = 0
-
-        with torch.no_grad():
-            for batch_data in val_loader:
-                if batch_data is None:
-                    continue
-
-                batch_loss = 0
-
-                if batch_data['denoise'] is not None:
-                    noisy_imgs, targets = batch_data['denoise']
-                    noisy_imgs = noisy_imgs.to(device)
-                    targets = targets.to(device)
-
-                    features = model(noisy_imgs)
-                    outputs = projection_head(features, task='denoise')
-                    loss = criterion_denoise(outputs, targets)
-                    batch_loss += loss
-
-                if batch_data['rotate'] is not None:
-                    imgs, targets = batch_data['rotate']
-                    imgs = imgs.to(device)
-                    targets = targets.long().to(device)
-
-                    features = model(imgs)
-                    outputs = projection_head(features, task='rotate')
-                    loss = criterion_rotate(outputs, targets)
-                    batch_loss += loss
-
-                if batch_loss > 0:
-                    val_loss += batch_loss.item()
-                    val_batch_count += 1
-
-        avg_val_loss = val_loss / max(val_batch_count, 1)
-        val_losses.append(avg_val_loss)
-        print(f"Validation Loss: {avg_val_loss:.4f}")
-
-        if avg_val_loss < best_val_loss:
-            best_val_loss = avg_val_loss
-            # Save both model and projection head
-            model.save("AIIA-base-512")
-            print("Best model saved!")
-
-    # Prepare the data to be written to the CSV file
-    data = list(zip(range(1, len(train_losses) + 1), train_losses, val_losses))
-
-    # Specify the CSV file name
-    csv_file = 'losses.csv'
-
-    # Write the data to the CSV file
-    with open(csv_file, mode='w', newline='') as file:
-        writer = csv.writer(file)
-        # Write the header
-        writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
-        # Write the data
-        writer.writerows(data)
-    print(f"Data has been written to {csv_file}")
-
-if __name__ == "__main__":
-    data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
-    data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
-    pretrain_model(data_path1, data_path2, num_epochs=10)
\ No newline at end of file