finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop on 2025-02-26 12:13:09 +00:00
1 changed file with 75 additions and 70 deletions
Showing only changes of commit 2121316e3b


@@ -5,10 +5,9 @@ import io
 from torch import nn
 from torch.utils.data import Dataset, DataLoader
 import torchvision.transforms as transforms
-from aiia.model import AIIABase
+from aiia.model import AIIABase, AIIA
 from sklearn.model_selection import train_test_split
-from typing import Dict, List, Union
+from typing import Dict, List, Union, Optional
 
 class ImageDataset(Dataset):
     def __init__(self, dataframe, transform=None):
@@ -36,15 +35,12 @@ class ImageDataset(Dataset):
             low_res_image = self.transform(low_res_image)
             high_res_image = self.transform(high_res_image)
-        return {'low_res': low_res_image, 'high_res': high_res_image}
+        return {'low_ress': low_res_image, 'high_ress': high_res_image}
 
 class ModelTrainer:
     def __init__(self,
-                 model_name: str = "AIIA-Base-512",
-                 dataset_paths: List[str] = None,
+                 model: AIIA,
+                 dataset_paths: List[str],
                  batch_size: int = 32,
                  learning_rate: float = 0.001,
                  num_workers: int = 4,
@@ -53,7 +49,7 @@ class ModelTrainer:
         Specialized trainer for image super resolution tasks
         Args:
-            model_name (str): Name of the model to initialize
+            model (nn.Module): Model instance to finetune
             dataset_paths (List[str]): Paths to datasets
             batch_size (int): Batch size for training
             learning_rate (float): Learning rate for optimizer
@@ -64,25 +60,24 @@
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.dataset_paths = dataset_paths
-        self.model_name = model_name
         self.learning_rate = learning_rate
         self.train_ratio = train_ratio
+        self.model = model
 
         # Initialize datasets and loaders
         self._initialize_datasets()
 
-        # Initialize model and training parameters
-        self._initialize_model()
+        # Initialize training parameters
+        self._initialize_training()
 
     def _initialize_datasets(self):
         """
         Helper method to initialize datasets
         """
-        # Read training data based on input format
         if isinstance(self.dataset_paths, list):
             df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True)
         else:
-            raise ValueError("Invalid dataset_paths format. Must be a list or dictionary.")
+            raise ValueError("Invalid dataset_paths format. Must be a list.")
 
         df_train, df_val = train_test_split(
             df_train,
@@ -114,18 +109,19 @@ class ModelTrainer:
             num_workers=self.num_workers
         ) if df_val is not None else None
 
-    def _initialize_model(self):
+    def _initialize_training(self):
         """
-        Helper method to initialize model architecture and training parameters
+        Helper method to initialize training parameters
         """
-        # Load base model
-        self.model = AIIABase.load(self.model_name)
-
-        # Freeze CNN layers
-        for param in self.model.cnn.parameters():
-            param.requires_grad = False
+        # Freeze CNN layers (if applicable)
+        try:
+            for param in self.model.cnn.parameters():
+                param.requires_grad = False
+        except AttributeError:
+            pass # If model doesn't have a 'cnn' attribute, just continue
 
-        # Add upscaling layer
-        hidden_size = self.model.config.hidden_size
-        kernel_size = self.model.config.kernel_size
-        self.model.upsample = nn.Sequential(
+        # Add upscaling layer if not already present
+        if not hasattr(self.model, 'upsample'):
+            hidden_size = self.model.config.hidden_size
+            kernel_size = self.model.config.kernel_size
+            self.model.upsample = nn.Sequential(
@@ -135,8 +131,14 @@
         # Initialize optimizer and loss function
         self.criterion = nn.MSELoss()
 
+        # Get parameters of the upsample layer for training
+        params = [p for p in self.model.upsample.parameters() if p.requires_grad]
+        if not params:
+            raise ValueError("No parameters found in upsample layer to optimize")
+
         self.optimizer = torch.optim.Adam(
-            [param for param in self.model.parameters() if 'upsample' in str(param)],
+            params,
             lr=self.learning_rate
         )
@@ -160,7 +162,7 @@ class ModelTrainer:
             # Save best model based on validation loss
             if self.val_loader is not None and self.current_val_loss < self.best_val_loss:
-                self.model.save("aiuNN-base")
+                self.model.save("aiuNN-finetuned")
 
     def _train_epoch(self):
         """
@@ -170,14 +172,14 @@ class ModelTrainer:
         running_loss = 0.0
 
         for batch in self.train_loader:
-            low_res = batch['low_ress'].to(self.device)
-            high_res = batch['high_ress'].to(self.device)
+            low_ress = batch['low_ress'].to(self.device)
+            high_ress = batch['high_ress'].to(self.device)
 
             # Forward pass
-            features = self.model.cnn(low_res)
+            features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
             outputs = self.model.upsample(features)
 
-            loss = self.criterion(outputs, high_res)
+            loss = self.criterion(outputs, high_ress)
 
             # Backward pass and optimize
             self.optimizer.zero_grad()
@@ -194,17 +196,17 @@ class ModelTrainer:
         Validate model performance
         """
         self.model.eval()
-        val_oss = 0.0
+        val_loss = 0.0
 
         with torch.no_grad():
             for batch in self.val_loader:
-                low_res = batch['low_ress'].to(self.device)
-                high_res = batch['high_ress'].to(self.device)
+                low_ress = batch['low_ress'].to(self.device)
+                high_ress = batch['high_ress'].to(self.device)
 
-                features = self.model.cnn(low_res)
+                features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
                 outputs = self.model.upsample(features)
-                loss = self.criterion(outputs, high_res)
+                loss = self.criterion(outputs, high_ress)
 
                 val_loss += loss.item()
 
         avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0
@@ -215,11 +217,14 @@ class ModelTrainer:
             self.best_val_loss = avg_val_loss
 
     def __repr__(self):
-        return f"Model ({self.model_name}, batch_size={self.batch_size})"
+        return f"ModelTrainer (model={type(self.model).__name__}, batch_size={self.batch_size})"
 
 if __name__ == "__main__":
+    # Load your model first
+    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
     trainer = ModelTrainer(
-        model_name="/root/vision/AIIA/AIIA-base-512/",
+        model=model,
        dataset_paths=[
            "/root/training_data/vision-dataset/image_upscaler.parquet",
            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
@@ -228,4 +233,4 @@ if __name__ == "__main__":
         learning_rate=0.001
     )
-    trainer.train(num__epochs=3)
+    trainer.train(num_epochs=3)
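
Reviewer note (not part of this PR): since the constructor now takes a model instance and _initialize_training() freezes the CNN backbone while adding an upsample head, a minimal sketch for checking the freeze locally is shown below. It assumes the model exposes a 'cnn' submodule as the diff does, and uses the `trainer` built in the __main__ block above.

# Minimal sketch: list which parameters remain trainable after construction.
# With the CNN backbone frozen, the entries should come from the newly added
# 'upsample' head (plus any other unfrozen submodules the model happens to define).
for name, param in trainer.model.named_parameters():
    if param.requires_grad:
        print("trainable:", name)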