finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed file with 75 additions and 70 deletions
Showing only changes of commit 2121316e3b


@@ -5,10 +5,9 @@ import io
 from torch import nn
 from torch.utils.data import Dataset, DataLoader
 import torchvision.transforms as transforms
-from aiia.model import AIIABase
+from aiia.model import AIIABase, AIIA
 from sklearn.model_selection import train_test_split
-from typing import Dict, List, Union
+from typing import Dict, List, Union, Optional

 class ImageDataset(Dataset):
     def __init__(self, dataframe, transform=None):
@@ -36,24 +35,21 @@ class ImageDataset(Dataset):
         low_res_image = self.transform(low_res_image)
         high_res_image = self.transform(high_res_image)
-        return {'low_res': low_res_image, 'high_res': high_res_image}
+        return {'low_ress': low_res_image, 'high_ress': high_res_image}

 class ModelTrainer:
     def __init__(self,
-                 model_name: str = "AIIA-Base-512",
-                 dataset_paths: List[str] = None,
+                 model: AIIA,
+                 dataset_paths: List[str],
                  batch_size: int = 32,
                  learning_rate: float = 0.001,
                  num_workers: int = 4,
                  train_ratio: float = 0.8):
         """
         Specialized trainer for image super resolution tasks

         Args:
-            model_name (str): Name of the model to initialize
+            model (nn.Module): Model instance to finetune
             dataset_paths (List[str]): Paths to datasets
             batch_size (int): Batch size for training
             learning_rate (float): Learning rate for optimizer
@@ -64,120 +60,126 @@ class ModelTrainer:
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.dataset_paths = dataset_paths
-        self.model_name = model_name
         self.learning_rate = learning_rate
         self.train_ratio = train_ratio
+        self.model = model

         # Initialize datasets and loaders
         self._initialize_datasets()

-        # Initialize model and training parameters
-        self._initialize_model()
+        # Initialize training parameters
+        self._initialize_training()

     def _initialize_datasets(self):
         """
         Helper method to initialize datasets
         """
-        # Read training data based on input format
         if isinstance(self.dataset_paths, list):
             df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True)
         else:
-            raise ValueError("Invalid dataset_paths format. Must be a list or dictionary.")
+            raise ValueError("Invalid dataset_paths format. Must be a list.")

         df_train, df_val = train_test_split(
             df_train,
             test_size=1 - self.train_ratio,
             random_state=42
         )

         # Define preprocessing transforms
         self.transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
         ])

         # Create datasets and dataloaders
         self.train_dataset = ImageDataset(df_train, transform=self.transform)
         self.val_dataset = ImageDataset(df_val, transform=self.transform)

         self.train_loader = DataLoader(
             self.train_dataset,
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=self.num_workers
         )

         self.val_loader = DataLoader(
             self.val_dataset,
             batch_size=self.batch_size,
             shuffle=False,
             num_workers=self.num_workers
         ) if df_val is not None else None

-    def _initialize_model(self):
+    def _initialize_training(self):
         """
-        Helper method to initialize model architecture and training parameters
+        Helper method to initialize training parameters
         """
-        # Load base model
-        self.model = AIIABase.load(self.model_name)
-
-        # Freeze CNN layers
-        for param in self.model.cnn.parameters():
-            param.requires_grad = False
-
-        # Add upscaling layer
-        hidden_size = self.model.config.hidden_size
-        kernel_size = self.model.config.kernel_size
-        self.model.upsample = nn.Sequential(
-            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
-            nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
-        )
+        # Freeze CNN layers (if applicable)
+        try:
+            for param in self.model.cnn.parameters():
+                param.requires_grad = False
+        except AttributeError:
+            pass  # If the model doesn't have a 'cnn' attribute, just continue
+
+        # Add upscaling layer if not already present
+        if not hasattr(self.model, 'upsample'):
+            hidden_size = self.model.config.hidden_size
+            kernel_size = self.model.config.kernel_size
+            self.model.upsample = nn.Sequential(
+                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+                nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
+            )

         # Initialize optimizer and loss function
         self.criterion = nn.MSELoss()
+
+        # Get parameters of the upsample layer for training
+        params = [p for p in self.model.upsample.parameters() if p.requires_grad]
+        if not params:
+            raise ValueError("No parameters found in upsample layer to optimize")
+
         self.optimizer = torch.optim.Adam(
-            [param for param in self.model.parameters() if 'upsample' in str(param)],
+            params,
             lr=self.learning_rate
         )
         self.best_val_loss = float('inf')

     def train(self, num_epochs: int = 10):
         """
         Train the model for specified number of epochs
         """
         self.model.to(self.device)

         for epoch in range(num_epochs):
             print(f"Epoch {epoch+1}/{num_epochs}")

             # Train phase
             self._train_epoch()

             # Validation phase
             if self.val_loader is not None:
                 self._validate_epoch()

             # Save best model based on validation loss
             if self.val_loader is not None and self.current_val_loss < self.best_val_loss:
-                self.model.save("aiuNN-base")
+                self.model.save("aiuNN-finetuned")

     def _train_epoch(self):
         """
         Train model for one epoch
         """
         self.model.train()
         running_loss = 0.0

         for batch in self.train_loader:
-            low_res = batch['low_ress'].to(self.device)
-            high_res = batch['high_ress'].to(self.device)
+            low_ress = batch['low_ress'].to(self.device)
+            high_ress = batch['high_ress'].to(self.device)

             # Forward pass
-            features = self.model.cnn(low_res)
+            features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
             outputs = self.model.upsample(features)

-            loss = self.criterion(outputs, high_res)
+            loss = self.criterion(outputs, high_ress)

             # Backward pass and optimize
             self.optimizer.zero_grad()
@@ -185,41 +187,44 @@ class ModelTrainer:
             self.optimizer.step()
             running_loss += loss.item()

         epoch_loss = running_loss / len(self.train_loader)
         print(f"Train Loss: {epoch_loss:.4f}")

     def _validate_epoch(self):
         """
         Validate model performance
         """
         self.model.eval()
-        val_oss = 0.0
+        val_loss = 0.0

         with torch.no_grad():
             for batch in self.val_loader:
-                low_res = batch['low_ress'].to(self.device)
-                high_res = batch['high_ress'].to(self.device)
+                low_ress = batch['low_ress'].to(self.device)
+                high_ress = batch['high_ress'].to(self.device)

-                features = self.model.cnn(low_res)
+                features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
                 outputs = self.model.upsample(features)

-                loss = self.criterion(outputs, high_res)
+                loss = self.criterion(outputs, high_ress)
                 val_loss += loss.item()

         avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0
         print(f"Validation Loss: {avg_val_loss:.4f}")

         # Update best model
         if avg_val_loss < self.best_val_loss:
             self.best_val_loss = avg_val_loss

     def __repr__(self):
-        return f"Model ({self.model_name}, batch_size={self.batch_size})"
+        return f"ModelTrainer (model={type(self.model).__name__}, batch_size={self.batch_size})"

 if __name__ == "__main__":
+    # Load your model first
+    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
+
     trainer = ModelTrainer(
-        model_name="/root/vision/AIIA/AIIA-base-512/",
+        model=model,
         dataset_paths=[
             "/root/training_data/vision-dataset/image_upscaler.parquet",
             "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
@ -227,5 +232,5 @@ if __name__ == "__main__":
batch_size=2, batch_size=2,
learning_rate=0.001 learning_rate=0.001
) )
trainer.train(num__epochs=3) trainer.train(num_epochs=3)
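
For reviewers who want to exercise the new dependency-injection API without the aiia package installed, the sketch below shows the duck-typed interface ModelTrainer now relies on: a model exposing cnn, config.hidden_size, config.kernel_size, and save(). TinyConfig and TinyModel are hypothetical stand-ins for illustration, not part of aiia.

# Hypothetical stand-in for an AIIA-style model, illustrating the minimal
# interface ModelTrainer assumes after this commit. Not the real AIIABase.
import torch
from torch import nn

class TinyConfig:
    hidden_size = 64   # channels produced by the cnn backbone
    kernel_size = 3    # kernel size reused by the added upsample head

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.config = TinyConfig()
        # 'cnn' attribute: frozen by _initialize_training via try/except
        self.cnn = nn.Sequential(
            nn.Conv2d(3, self.config.hidden_size, kernel_size=3, padding=1),
            nn.ReLU()
        )

    def save(self, path):
        # ModelTrainer calls model.save(...) to checkpoint the best weights
        torch.save(self.state_dict(), f"{path}.pt")

# Quick shape check of the forward path used in _train_epoch:
model = TinyModel()
# _initialize_training would attach model.upsample; mimic it here
model.upsample = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
    nn.Conv2d(model.config.hidden_size, 3,
              kernel_size=model.config.kernel_size, padding=1)
)
x = torch.randn(1, 3, 32, 32)       # a fake low-res batch
out = model.upsample(model.cnn(x))  # same call chain as the trainer
print(out.shape)                    # torch.Size([1, 3, 64, 64])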
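
A side note on the optimizer change: str(param) renders a Parameter's values, never its attribute name, so the old filter ['upsample' in str(param)] would have matched nothing and Adam would have raised "optimizer got an empty parameter list". The commit's explicit params list avoids this; the minimal demonstration below (assumed, for review context) also shows the usual name-based alternative via named_parameters().

# Why the old string-based filter failed, and a name-based alternative.
import torch
from torch import nn

model = nn.Module()
model.cnn = nn.Conv2d(3, 8, 3)
model.upsample = nn.Conv2d(8, 3, 3)

# str(p) prints tensor values, not the attribute name, so this selects nothing:
old_style = [p for p in model.parameters() if 'upsample' in str(p)]
print(len(old_style))  # 0

# Filtering by name is the idiomatic way to select a submodule's parameters:
by_name = [p for name, p in model.named_parameters()
           if name.startswith('upsample')]
print(len(by_name))    # 2 (weight and bias)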