updated pretraining to create an extra class for Pretraining

Falko Victor Habel 2025-01-28 11:16:09 +01:00
parent 1e665c4604
commit 7de7eef081
5 changed files with 241 additions and 226 deletions

View File

@@ -1,2 +1,19 @@
# AIIA
## Example Usage:
```Python
if __name__ == "__main__":
data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer
config = AIIAConfig(model_name="AIIA-Base-512x20k")
model = AIIABase(config)
pretrainer = Pretrainer(model, learning_rate=1e-4)
pretrainer.train(data_path1, data_path2, num_epochs=10)
```
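During training the best checkpoint is saved automatically: the base model via `model.save(...)` plus a `*_projection_head.pth` state dict (see `Pretrainer.save_model` below). A minimal sketch of restoring the projection head afterwards with plain `torch`, assuming the default `"AIIA-base-512"` save name used by `train`:
```Python
import torch
from aiia.pretrain import ProjectionHead

# Rebuild the head and load the weights written by Pretrainer.save_model("AIIA-base-512")
head = ProjectionHead()
state = torch.load("AIIA-base-512_projection_head.pth", map_location="cpu")
head.load_state_dict(state)
head.eval()
```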

View File

@@ -1,5 +1,7 @@
from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAExpert, AIIAmoe, AIIA, AIIArecursive
from .model.config import AIIAConfig
from .data.DataLoader import DataLoader
from .pretrain.pretrainer import Pretrainer, ProjectionHead

__version__ = "0.1.0"

View File

@@ -0,0 +1,3 @@
from .pretrainer import Pretrainer, ProjectionHead
__all__ = ["Pretrainer", "ProjectionHead"]

View File

@@ -0,0 +1,219 @@
import torch
from torch import nn
import csv
import pandas as pd
from tqdm import tqdm
from ..model.Model import AIIA
from ..data.DataLoader import AIIADataLoader

class ProjectionHead(nn.Module):
def __init__(self):
super().__init__()
self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1) # 4 classes for 0, 90, 180, 270 degrees
def forward(self, x, task='denoise'):
if task == 'denoise':
return self.conv_denoise(x)
else:
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
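
# Shape sketch (illustrative): with input features of shape (B, 512, H, W),
#   ProjectionHead()(x, task='denoise') -> (B, 3, H, W)  # image-space reconstruction
#   ProjectionHead()(x, task='rotate')  -> (B, 4)        # rotation-class logits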

class Pretrainer:
def __init__(self, model: AIIA, learning_rate=1e-4):
"""
Initialize the pretrainer with a model.
Args:
model (AIIA): The model instance to pretrain
learning_rate (float): Learning rate for optimization
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = model.to(self.device)
self.projection_head = ProjectionHead().to(self.device)
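        # One optimizer updates both the backbone and the projection head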
self.optimizer = torch.optim.AdamW(
list(self.model.parameters()) + list(self.projection_head.parameters()),
lr=learning_rate
)
self.train_losses = []
self.val_losses = []

    @staticmethod
def safe_collate(batch):
"""Safely collate batch data handling both denoise and rotate tasks."""
denoise_batch = []
rotate_batch = []
for sample in batch:
try:
noisy_img, target, task = sample
if task == 'denoise':
denoise_batch.append({
'image': noisy_img,
'target': target,
'task': task
})
else: # rotate task
rotate_batch.append({
'image': noisy_img,
'target': target,
'task': task
})
except Exception as e:
print(f"Skipping sample due to error: {e}")
continue
if not denoise_batch and not rotate_batch:
return None
batch_data = {
'denoise': None,
'rotate': None
}
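        # Each task key is filled below with a stacked (images, targets) tuple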
if denoise_batch:
images = torch.stack([x['image'] for x in denoise_batch])
targets = torch.stack([x['target'] for x in denoise_batch])
batch_data['denoise'] = (images, targets)
if rotate_batch:
images = torch.stack([x['image'] for x in rotate_batch])
targets = torch.stack([x['target'] for x in rotate_batch])
batch_data['rotate'] = (images, targets)
return batch_data

    def _process_batch(self, batch_data, criterion_denoise, criterion_rotate, training=True):
"""Process a single batch of data."""
batch_loss = 0
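        # Sum the per-task losses into one scalar so a single backward() call
        # updates the shared backbone and the projection head for both objectives.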
if batch_data['denoise'] is not None:
noisy_imgs, targets = batch_data['denoise']
noisy_imgs = noisy_imgs.to(self.device)
targets = targets.to(self.device)
features = self.model(noisy_imgs)
outputs = self.projection_head(features, task='denoise')
loss = criterion_denoise(outputs, targets)
batch_loss += loss
if batch_data['rotate'] is not None:
imgs, targets = batch_data['rotate']
imgs = imgs.to(self.device)
targets = targets.long().to(self.device)
features = self.model(imgs)
outputs = self.projection_head(features, task='rotate')
loss = criterion_rotate(outputs, targets)
batch_loss += loss
return batch_loss

    def train(self, data_path1, data_path2, num_epochs=3, batch_size=2, sample_size=10000):
"""
Train the model using the specified datasets.
Args:
data_path1 (str): Path to first dataset
data_path2 (str): Path to second dataset
num_epochs (int): Number of training epochs
batch_size (int): Batch size for training
sample_size (int): Number of samples to use from each dataset
"""
# Read and merge datasets
df1 = pd.read_parquet(data_path1).head(sample_size)
df2 = pd.read_parquet(data_path2).head(sample_size)
merged_df = pd.concat([df1, df2], ignore_index=True)
# Initialize data loader
aiia_loader = AIIADataLoader(
merged_df,
column="image_bytes",
batch_size=batch_size,
pretraining=True,
collate_fn=self.safe_collate
)
criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()
best_val_loss = float('inf')
for epoch in range(num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
print("-" * 20)
# Training phase
self.model.train()
self.projection_head.train()
total_train_loss = 0.0
batch_count = 0
for batch_data in tqdm(aiia_loader.train_loader):
if batch_data is None:
continue
self.optimizer.zero_grad()
batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
if batch_loss > 0:
batch_loss.backward()
self.optimizer.step()
total_train_loss += batch_loss.item()
batch_count += 1
avg_train_loss = total_train_loss / max(batch_count, 1)
self.train_losses.append(avg_train_loss)
print(f"Training Loss: {avg_train_loss:.4f}")
# Validation phase
self.model.eval()
self.projection_head.eval()
val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
if val_loss < best_val_loss:
best_val_loss = val_loss
self.save_model("AIIA-base-512")
print("Best model saved!")
self.save_losses('losses.csv')

    def _validate(self, val_loader, criterion_denoise, criterion_rotate):
"""Perform validation and return average validation loss."""
val_loss = 0.0
val_batch_count = 0
with torch.no_grad():
for batch_data in val_loader:
if batch_data is None:
continue
batch_loss = self._process_batch(
batch_data, criterion_denoise, criterion_rotate, training=False
)
if batch_loss > 0:
val_loss += batch_loss.item()
val_batch_count += 1
avg_val_loss = val_loss / max(val_batch_count, 1)
self.val_losses.append(avg_val_loss)
print(f"Validation Loss: {avg_val_loss:.4f}")
return avg_val_loss

    def save_model(self, path):
"""Save the model and projection head."""
self.model.save(path)
torch.save(self.projection_head.state_dict(), f"{path}_projection_head.pth")

    def save_losses(self, csv_file):
"""Save training and validation losses to a CSV file."""
data = list(zip(
range(1, len(self.train_losses) + 1),
self.train_losses,
self.val_losses
))
with open(csv_file, mode='w', newline='') as file:
writer = csv.writer(file)
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
writer.writerows(data)
print(f"Loss data has been written to {csv_file}")

View File

@@ -1,226 +0,0 @@
import torch
from torch import nn
import csv
import pandas as pd
from aiia.model.config import AIIAConfig
from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader
from tqdm import tqdm

class ProjectionHead(nn.Module):
def __init__(self):
super().__init__()
self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1) # 4 classes for 0, 90, 180, 270 degrees
def forward(self, x, task='denoise'):
if task == 'denoise':
return self.conv_denoise(x)
else:
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task

def pretrain_model(data_path1, data_path2, num_epochs=3):
# Read and merge datasets
df1 = pd.read_parquet(data_path1).head(10000)
df2 = pd.read_parquet(data_path2).head(10000)
merged_df = pd.concat([df1, df2], ignore_index=True)
# Model configuration
config = AIIAConfig(
model_name="AIIA-Base-512x20k",
)
# Initialize model and projection head
model = AIIABase(config)
projection_head = ProjectionHead()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
projection_head.to(device)
def safe_collate(batch):
denoise_batch = []
rotate_batch = []
for sample in batch:
try:
noisy_img, target, task = sample
if task == 'denoise':
denoise_batch.append({
'image': noisy_img,
'target': target,
'task': task
})
else: # rotate task
rotate_batch.append({
'image': noisy_img,
'target': target,
'task': task
})
except Exception as e:
print(f"Skipping sample due to error: {e}")
continue
if not denoise_batch and not rotate_batch:
return None
batch_data = {
'denoise': None,
'rotate': None
}
if denoise_batch:
images = torch.stack([x['image'] for x in denoise_batch])
targets = torch.stack([x['target'] for x in denoise_batch])
batch_data['denoise'] = (images, targets)
if rotate_batch:
images = torch.stack([x['image'] for x in rotate_batch])
targets = torch.stack([x['target'] for x in rotate_batch])
batch_data['rotate'] = (images, targets)
return batch_data
aiia_loader = AIIADataLoader(
merged_df,
column="image_bytes",
batch_size=2,
pretraining=True,
collate_fn=safe_collate
)
train_loader = aiia_loader.train_loader
val_loader = aiia_loader.val_loader
criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()
# Update optimizer to include projection head parameters
optimizer = torch.optim.AdamW(
list(model.parameters()) + list(projection_head.parameters()),
lr=config.learning_rate
)
best_val_loss = float('inf')
train_losses = []
val_losses = []
for epoch in range(num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
print("-" * 20)
# Training phase
model.train()
projection_head.train()
total_train_loss = 0.0
batch_count = 0
for batch_data in tqdm(train_loader):
if batch_data is None:
continue
optimizer.zero_grad()
batch_loss = 0
# Handle denoise task
if batch_data['denoise'] is not None:
noisy_imgs, targets = batch_data['denoise']
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
# Get features from base model
features = model(noisy_imgs)
# Project features back to image space
outputs = projection_head(features, task='denoise')
loss = criterion_denoise(outputs, targets)
batch_loss += loss
# Handle rotate task
if batch_data['rotate'] is not None:
imgs, targets = batch_data['rotate']
imgs = imgs.to(device)
targets = targets.long().to(device)
# Get features from base model
features = model(imgs)
# Project features to rotation predictions
outputs = projection_head(features, task='rotate')
loss = criterion_rotate(outputs, targets)
batch_loss += loss
if batch_loss > 0:
batch_loss.backward()
optimizer.step()
total_train_loss += batch_loss.item()
batch_count += 1
avg_train_loss = total_train_loss / max(batch_count, 1)
train_losses.append(avg_train_loss)
print(f"Training Loss: {avg_train_loss:.4f}")
# Validation phase
model.eval()
projection_head.eval()
val_loss = 0.0
val_batch_count = 0
with torch.no_grad():
for batch_data in val_loader:
if batch_data is None:
continue
batch_loss = 0
if batch_data['denoise'] is not None:
noisy_imgs, targets = batch_data['denoise']
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
features = model(noisy_imgs)
outputs = projection_head(features, task='denoise')
loss = criterion_denoise(outputs, targets)
batch_loss += loss
if batch_data['rotate'] is not None:
imgs, targets = batch_data['rotate']
imgs = imgs.to(device)
targets = targets.long().to(device)
features = model(imgs)
outputs = projection_head(features, task='rotate')
loss = criterion_rotate(outputs, targets)
batch_loss += loss
if batch_loss > 0:
val_loss += batch_loss.item()
val_batch_count += 1
avg_val_loss = val_loss / max(val_batch_count, 1)
val_losses.append(avg_val_loss)
print(f"Validation Loss: {avg_val_loss:.4f}")
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
# Save both model and projection head
model.save("AIIA-base-512")
print("Best model saved!")
# Prepare the data to be written to the CSV file
data = list(zip(range(1, len(train_losses) + 1), train_losses, val_losses))
# Specify the CSV file name
csv_file = 'losses.csv'
# Write the data to the CSV file
with open(csv_file, mode='w', newline='') as file:
writer = csv.writer(file)
# Write the header
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
# Write the data
writer.writerows(data)
print(f"Data has been written to {csv_file}")
if __name__ == "__main__":
data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
pretrain_model(data_path1, data_path2, num_epochs=10)