updated pretraining to create an extra Pretrainer class

This commit is contained in:
Falko Victor Habel 2025-01-28 11:16:09 +01:00
parent 1e665c4604
commit 7de7eef081
5 changed files with 241 additions and 226 deletions

View File

@@ -1,2 +1,19 @@
# AIIA
## Example Usage:
```Python
if __name__ == "__main__":
    from aiia.model import AIIABase
    from aiia.model.config import AIIAConfig
    from aiia.pretrain import Pretrainer

    # Parquet files containing the image bytes used for pretraining
    data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
    data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"

    config = AIIAConfig(model_name="AIIA-Base-512x20k")
    model = AIIABase(config)

    pretrainer = Pretrainer(model, learning_rate=1e-4)
    pretrainer.train(data_path1, data_path2, num_epochs=10)
```
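
The `Pretrainer` drives two self-supervised tasks (denoising and rotation prediction) through a shared `ProjectionHead`. A minimal sketch of the head's output shapes, assuming a 16×16 feature map from the backbone (the actual spatial size depends on the model):

```Python
import torch
from aiia.pretrain import ProjectionHead

head = ProjectionHead()
features = torch.randn(2, 512, 16, 16)  # dummy backbone features; shape is an assumption
print(head(features, task='denoise').shape)  # torch.Size([2, 3, 16, 16]) -- per-pixel reconstruction
print(head(features, task='rotate').shape)   # torch.Size([2, 4]) -- logits for 0/90/180/270 degrees
```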

View File

@@ -1,5 +1,7 @@
from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAExpert, AIIAmoe, AIIA, AIIArecursive
from .model.config import AIIAConfig
from .data.DataLoader import DataLoader
from .pretrain.pretrainer import Pretrainer, ProjectionHead
__version__ = "0.1.0"

View File

@@ -0,0 +1,3 @@
from .pretrainer import Pretrainer, ProjectionHead
__all__ = ["Pretrainer", "ProjectionHead"]

View File

@@ -0,0 +1,219 @@
import torch
from torch import nn
import csv
import pandas as pd
from tqdm import tqdm
from ..model.Model import AIIA
from ..data.DataLoader import AIIADataLoader

class ProjectionHead(nn.Module):
    """Maps 512-channel backbone features to the two pretraining task outputs."""
    def __init__(self):
        super().__init__()
        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        return self.conv_rotate(x).mean(dim=(2, 3))  # global average pooling for rotation task

class Pretrainer:
    def __init__(self, model: AIIA, learning_rate=1e-4):
        """
        Initialize the pretrainer with a model.

        Args:
            model (AIIA): The model instance to pretrain
            learning_rate (float): Learning rate for optimization
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.to(self.device)
        self.projection_head = ProjectionHead().to(self.device)
        self.optimizer = torch.optim.AdamW(
            list(self.model.parameters()) + list(self.projection_head.parameters()),
            lr=learning_rate
        )
        self.train_losses = []
        self.val_losses = []
    @staticmethod
    def safe_collate(batch):
        """Safely collate batch data handling both denoise and rotate tasks."""
        denoise_batch = []
        rotate_batch = []
        for sample in batch:
            try:
                noisy_img, target, task = sample
                if task == 'denoise':
                    denoise_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
                else:  # rotate task
                    rotate_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
            except Exception as e:
                print(f"Skipping sample due to error: {e}")
                continue

        if not denoise_batch and not rotate_batch:
            return None

        batch_data = {
            'denoise': None,
            'rotate': None
        }
        if denoise_batch:
            images = torch.stack([x['image'] for x in denoise_batch])
            targets = torch.stack([x['target'] for x in denoise_batch])
            batch_data['denoise'] = (images, targets)
        if rotate_batch:
            images = torch.stack([x['image'] for x in rotate_batch])
            targets = torch.stack([x['target'] for x in rotate_batch])
            batch_data['rotate'] = (images, targets)
        return batch_data
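    # Illustration (hypothetical shapes): for a mixed batch, safe_collate returns
    #   {'denoise': (images [N, 3, H, W], targets [N, 3, H, W]),
    #    'rotate':  (images [M, 3, H, W], targets [M])}
    # with None for any task absent from the batch, or None overall when every
    # sample failed to unpack.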
    def _process_batch(self, batch_data, criterion_denoise, criterion_rotate, training=True):
        """Process a single batch of data (`training` is accepted but currently unused)."""
        batch_loss = 0
        if batch_data['denoise'] is not None:
            noisy_imgs, targets = batch_data['denoise']
            noisy_imgs = noisy_imgs.to(self.device)
            targets = targets.to(self.device)
            features = self.model(noisy_imgs)
            outputs = self.projection_head(features, task='denoise')
            loss = criterion_denoise(outputs, targets)
            batch_loss += loss

        if batch_data['rotate'] is not None:
            imgs, targets = batch_data['rotate']
            imgs = imgs.to(self.device)
            targets = targets.long().to(self.device)
            features = self.model(imgs)
            outputs = self.projection_head(features, task='rotate')
            loss = criterion_rotate(outputs, targets)
            batch_loss += loss

        return batch_loss
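    # Note: when at least one task is present, the returned batch_loss is a torch
    # scalar (summed MSE and cross-entropy terms), so the single backward() call
    # in train() propagates gradients through the backbone and both heads.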
    def train(self, data_path1, data_path2, num_epochs=3, batch_size=2, sample_size=10000):
        """
        Train the model using the specified datasets.

        Args:
            data_path1 (str): Path to first dataset
            data_path2 (str): Path to second dataset
            num_epochs (int): Number of training epochs
            batch_size (int): Batch size for training
            sample_size (int): Number of samples to use from each dataset
        """
        # Read and merge datasets
        df1 = pd.read_parquet(data_path1).head(sample_size)
        df2 = pd.read_parquet(data_path2).head(sample_size)
        merged_df = pd.concat([df1, df2], ignore_index=True)

        # Initialize data loader
        aiia_loader = AIIADataLoader(
            merged_df,
            column="image_bytes",
            batch_size=batch_size,
            pretraining=True,
            collate_fn=self.safe_collate
        )
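        # Note: AIIADataLoader is assumed to handle the train/validation split
        # internally; its .train_loader and .val_loader are consumed below.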
        criterion_denoise = nn.MSELoss()
        criterion_rotate = nn.CrossEntropyLoss()
        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("-" * 20)

            # Training phase
            self.model.train()
            self.projection_head.train()
            total_train_loss = 0.0
            batch_count = 0

            for batch_data in tqdm(aiia_loader.train_loader):
                if batch_data is None:
                    continue

                self.optimizer.zero_grad()
                batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)

                if batch_loss > 0:
                    batch_loss.backward()
                    self.optimizer.step()
                    total_train_loss += batch_loss.item()
                    batch_count += 1

            avg_train_loss = total_train_loss / max(batch_count, 1)
            self.train_losses.append(avg_train_loss)
            print(f"Training Loss: {avg_train_loss:.4f}")

            # Validation phase
            self.model.eval()
            self.projection_head.eval()
            val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_model("AIIA-base-512")
                print("Best model saved!")

        self.save_losses('losses.csv')
    def _validate(self, val_loader, criterion_denoise, criterion_rotate):
        """Perform validation and return the average validation loss."""
        val_loss = 0.0
        val_batch_count = 0

        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue

                batch_loss = self._process_batch(
                    batch_data, criterion_denoise, criterion_rotate, training=False
                )
                if batch_loss > 0:
                    val_loss += batch_loss.item()
                    val_batch_count += 1

        avg_val_loss = val_loss / max(val_batch_count, 1)
        self.val_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.4f}")
        return avg_val_loss
    def save_model(self, path):
        """Save the base model and the projection head weights."""
        self.model.save(path)
        torch.save(self.projection_head.state_dict(), f"{path}_projection_head.pth")

    def save_losses(self, csv_file):
        """Save training and validation losses to a CSV file."""
        data = list(zip(
            range(1, len(self.train_losses) + 1),
            self.train_losses,
            self.val_losses
        ))
        with open(csv_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
            writer.writerows(data)
        print(f"Loss data has been written to {csv_file}")

View File

@@ -1,226 +0,0 @@
import torch
from torch import nn
import csv
import pandas as pd
from aiia.model.config import AIIAConfig
from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader
from tqdm import tqdm

class ProjectionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        else:
            return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task

def pretrain_model(data_path1, data_path2, num_epochs=3):
    # Read and merge datasets
    df1 = pd.read_parquet(data_path1).head(10000)
    df2 = pd.read_parquet(data_path2).head(10000)
    merged_df = pd.concat([df1, df2], ignore_index=True)

    # Model configuration
    config = AIIAConfig(
        model_name="AIIA-Base-512x20k",
    )

    # Initialize model and projection head
    model = AIIABase(config)
    projection_head = ProjectionHead()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    projection_head.to(device)
    def safe_collate(batch):
        denoise_batch = []
        rotate_batch = []
        for sample in batch:
            try:
                noisy_img, target, task = sample
                if task == 'denoise':
                    denoise_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
                else:  # rotate task
                    rotate_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
            except Exception as e:
                print(f"Skipping sample due to error: {e}")
                continue

        if not denoise_batch and not rotate_batch:
            return None

        batch_data = {
            'denoise': None,
            'rotate': None
        }
        if denoise_batch:
            images = torch.stack([x['image'] for x in denoise_batch])
            targets = torch.stack([x['target'] for x in denoise_batch])
            batch_data['denoise'] = (images, targets)
        if rotate_batch:
            images = torch.stack([x['image'] for x in rotate_batch])
            targets = torch.stack([x['target'] for x in rotate_batch])
            batch_data['rotate'] = (images, targets)
        return batch_data
    aiia_loader = AIIADataLoader(
        merged_df,
        column="image_bytes",
        batch_size=2,
        pretraining=True,
        collate_fn=safe_collate
    )

    train_loader = aiia_loader.train_loader
    val_loader = aiia_loader.val_loader

    criterion_denoise = nn.MSELoss()
    criterion_rotate = nn.CrossEntropyLoss()

    # Update optimizer to include projection head parameters
    optimizer = torch.optim.AdamW(
        list(model.parameters()) + list(projection_head.parameters()),
        lr=config.learning_rate
    )

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 20)

        # Training phase
        model.train()
        projection_head.train()
        total_train_loss = 0.0
        batch_count = 0

        for batch_data in tqdm(train_loader):
            if batch_data is None:
                continue

            optimizer.zero_grad()
            batch_loss = 0

            # Handle denoise task
            if batch_data['denoise'] is not None:
                noisy_imgs, targets = batch_data['denoise']
                noisy_imgs = noisy_imgs.to(device)
                targets = targets.to(device)

                # Get features from base model
                features = model(noisy_imgs)
                # Project features back to image space
                outputs = projection_head(features, task='denoise')
                loss = criterion_denoise(outputs, targets)
                batch_loss += loss

            # Handle rotate task
            if batch_data['rotate'] is not None:
                imgs, targets = batch_data['rotate']
                imgs = imgs.to(device)
                targets = targets.long().to(device)

                # Get features from base model
                features = model(imgs)
                # Project features to rotation predictions
                outputs = projection_head(features, task='rotate')
                loss = criterion_rotate(outputs, targets)
                batch_loss += loss

            if batch_loss > 0:
                batch_loss.backward()
                optimizer.step()
                total_train_loss += batch_loss.item()
                batch_count += 1

        avg_train_loss = total_train_loss / max(batch_count, 1)
        train_losses.append(avg_train_loss)
        print(f"Training Loss: {avg_train_loss:.4f}")
        # Validation phase
        model.eval()
        projection_head.eval()
        val_loss = 0.0
        val_batch_count = 0

        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue

                batch_loss = 0
                if batch_data['denoise'] is not None:
                    noisy_imgs, targets = batch_data['denoise']
                    noisy_imgs = noisy_imgs.to(device)
                    targets = targets.to(device)
                    features = model(noisy_imgs)
                    outputs = projection_head(features, task='denoise')
                    loss = criterion_denoise(outputs, targets)
                    batch_loss += loss

                if batch_data['rotate'] is not None:
                    imgs, targets = batch_data['rotate']
                    imgs = imgs.to(device)
                    targets = targets.long().to(device)
                    features = model(imgs)
                    outputs = projection_head(features, task='rotate')
                    loss = criterion_rotate(outputs, targets)
                    batch_loss += loss

                if batch_loss > 0:
                    val_loss += batch_loss.item()
                    val_batch_count += 1

        avg_val_loss = val_loss / max(val_batch_count, 1)
        val_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Save both model and projection head
            model.save("AIIA-base-512")
            print("Best model saved!")
    # Prepare the data to be written to the CSV file
    data = list(zip(range(1, len(train_losses) + 1), train_losses, val_losses))

    # Specify the CSV file name
    csv_file = 'losses.csv'

    # Write the data to the CSV file
    with open(csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Write the header
        writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
        # Write the data
        writer.writerows(data)

    print(f"Data has been written to {csv_file}")


if __name__ == "__main__":
    data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
    data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
    pretrain_model(data_path1, data_path2, num_epochs=10)