finetune_class #1
@@ -5,10 +5,9 @@ import io
 from torch import nn
 from torch.utils.data import Dataset, DataLoader
 import torchvision.transforms as transforms
-from aiia.model import AIIABase
+from aiia.model import AIIABase, AIIA
 from sklearn.model_selection import train_test_split
-from typing import Dict, List, Union
+from typing import Dict, List, Union, Optional
 
 
 class ImageDataset(Dataset):
     def __init__(self, dataframe, transform=None):
@@ -36,15 +35,12 @@ class ImageDataset(Dataset):
             low_res_image = self.transform(low_res_image)
             high_res_image = self.transform(high_res_image)
 
-        return {'low_res': low_res_image, 'high_res': high_res_image}
+        return {'low_ress': low_res_image, 'high_ress': high_res_image}
 
 
 class ModelTrainer:
     def __init__(self,
-                 model_name: str = "AIIA-Base-512",
-                 dataset_paths: List[str] = None,
+                 model: AIIA,
+                 dataset_paths: List[str],
                  batch_size: int = 32,
                  learning_rate: float = 0.001,
                  num_workers: int = 4,
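Note: the renamed batch keys ('low_res'/'high_res' to 'low_ress'/'high_ress') have to match everywhere the batch is read; DataLoader's default collate keys the stacked tensors exactly as the dataset returns them, which is why the later hunks in _train_epoch and _validate switch too. A minimal sketch of that behaviour (shapes are hypothetical; assumes PyTorch >= 1.11 for the default_collate import):

    import torch
    from torch.utils.data import default_collate

    # default_collate turns a list of per-sample dicts into one dict of
    # stacked tensors, keyed exactly as the dataset returned them.
    samples = [{'low_ress': torch.zeros(3, 16, 16), 'high_ress': torch.zeros(3, 32, 32)}
               for _ in range(4)]
    batch = default_collate(samples)
    print(batch['low_ress'].shape)  # torch.Size([4, 3, 16, 16])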
@@ -53,7 +49,7 @@ class ModelTrainer:
         Specialized trainer for image super resolution tasks
 
         Args:
-            model_name (str): Name of the model to initialize
+            model (nn.Module): Model instance to finetune
            dataset_paths (List[str]): Paths to datasets
            batch_size (int): Batch size for training
            learning_rate (float): Learning rate for optimizer
@@ -64,25 +60,24 @@ class ModelTrainer:
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.dataset_paths = dataset_paths
-        self.model_name = model_name
         self.learning_rate = learning_rate
         self.train_ratio = train_ratio
+        self.model = model
 
         # Initialize datasets and loaders
         self._initialize_datasets()
 
-        # Initialize model and training parameters
-        self._initialize_model()
+        # Initialize training parameters
+        self._initialize_training()
 
     def _initialize_datasets(self):
         """
         Helper method to initialize datasets
         """
-        # Read training data based on input format
         if isinstance(self.dataset_paths, list):
             df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True)
         else:
-            raise ValueError("Invalid dataset_paths format. Must be a list or dictionary.")
+            raise ValueError("Invalid dataset_paths format. Must be a list.")
 
         df_train, df_val = train_test_split(
             df_train,
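Note: _initialize_datasets concatenates every parquet shard into one frame before a single train/val split. The same logic in isolation, as a sketch (the paths and the 0.8 ratio are placeholders; train_ratio's default is outside this hunk):

    import pandas as pd
    from sklearn.model_selection import train_test_split

    # Concatenate all parquet shards, then split once into train/val.
    paths = ["a.parquet", "b.parquet"]  # hypothetical shard paths
    df = pd.concat([pd.read_parquet(p) for p in paths], ignore_index=True)
    df_train, df_val = train_test_split(df, train_size=0.8)  # assumed ratio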
@@ -114,18 +109,19 @@ class ModelTrainer:
             num_workers=self.num_workers
         ) if df_val is not None else None
 
-    def _initialize_model(self):
+    def _initialize_training(self):
         """
-        Helper method to initialize model architecture and training parameters
+        Helper method to initialize training parameters
         """
-        # Load base model
-        self.model = AIIABase.load(self.model_name)
-
-        # Freeze CNN layers
-        for param in self.model.cnn.parameters():
-            param.requires_grad = False
+        # Freeze CNN layers (if applicable)
+        try:
+            for param in self.model.cnn.parameters():
+                param.requires_grad = False
+        except AttributeError:
+            pass  # If model doesn't have a 'cnn' attribute, just continue
 
-        # Add upscaling layer
-        hidden_size = self.model.config.hidden_size
-        kernel_size = self.model.config.kernel_size
-        self.model.upsample = nn.Sequential(
+        # Add upscaling layer if not already present
+        if not hasattr(self.model, 'upsample'):
+            hidden_size = self.model.config.hidden_size
+            kernel_size = self.model.config.kernel_size
+            self.model.upsample = nn.Sequential(
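Note: the nn.Sequential body stays elided between this hunk and the next, so the PR never shows the head's exact layers. Purely as an illustration of what a config-driven upscaling head could look like (every value and layer below is an assumption, not taken from AIIA):

    import torch
    from torch import nn

    hidden_size, kernel_size, scale = 512, 3, 2  # assumed config values
    upsample = nn.Sequential(
        # Expand channels by scale**2, then let PixelShuffle rearrange them
        # into a spatially 2x larger feature map.
        nn.Conv2d(hidden_size, hidden_size * scale ** 2, kernel_size, padding=kernel_size // 2),
        nn.PixelShuffle(scale),
        nn.Conv2d(hidden_size, 3, kernel_size, padding=kernel_size // 2),  # back to RGB
    )
    print(upsample(torch.randn(1, 512, 16, 16)).shape)  # torch.Size([1, 3, 32, 32])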
@@ -135,8 +131,14 @@ class ModelTrainer:
 
         # Initialize optimizer and loss function
         self.criterion = nn.MSELoss()
+
+        # Get parameters of the upsample layer for training
+        params = [p for p in self.model.upsample.parameters() if p.requires_grad]
+        if not params:
+            raise ValueError("No parameters found in upsample layer to optimize")
 
         self.optimizer = torch.optim.Adam(
-            [param for param in self.model.parameters() if 'upsample' in str(param)],
+            params,
             lr=self.learning_rate
         )
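Note: the replaced filter was the actual bug here: str(param) renders the tensor's values, never an attribute name, so 'upsample' in str(param) matched nothing and Adam raised on an empty parameter list. Collecting self.model.upsample.parameters() directly fixes that; an equivalent name-based selection, sketched on a toy module:

    import torch
    from torch import nn

    # Toy stand-in: a frozen backbone plus a trainable 'upsample' head.
    model = nn.Module()
    model.cnn = nn.Conv2d(3, 8, 3)
    model.upsample = nn.Conv2d(8, 3, 3)
    for p in model.cnn.parameters():
        p.requires_grad = False

    # Parameter names only exist on named_parameters(), so select there.
    params = [p for name, p in model.named_parameters()
              if name.startswith("upsample") and p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=1e-3)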
@@ -160,7 +162,7 @@ class ModelTrainer:
 
         # Save best model based on validation loss
         if self.val_loader is not None and self.current_val_loss < self.best_val_loss:
-            self.model.save("aiuNN-base")
+            self.model.save("aiuNN-finetuned")
 
     def _train_epoch(self):
         """
@@ -170,14 +172,14 @@ class ModelTrainer:
         running_loss = 0.0
 
         for batch in self.train_loader:
-            low_res = batch['low_ress'].to(self.device)
-            high_res = batch['high_ress'].to(self.device)
+            low_ress = batch['low_ress'].to(self.device)
+            high_ress = batch['high_ress'].to(self.device)
 
             # Forward pass
-            features = self.model.cnn(low_res)
+            features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
             outputs = self.model.upsample(features)
 
-            loss = self.criterion(outputs, high_res)
+            loss = self.criterion(outputs, high_ress)
 
             # Backward pass and optimize
             self.optimizer.zero_grad()
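Note: the hasattr branch makes the forward step work for backbones with either interface (extract_features comes from this diff, not invented here), and the identical expression reappears in _validate below, so it could be factored out; a minimal sketch:

    # Prefer the raw CNN when the backbone exposes one, otherwise fall
    # back to the model's own feature extractor.
    def forward_features(model, low_ress):
        if hasattr(model, 'cnn'):
            return model.cnn(low_ress)
        return model.extract_features(low_ress)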
@@ -194,17 +196,17 @@ class ModelTrainer:
         Validate model performance
         """
         self.model.eval()
-        val_oss = 0.0
+        val_loss = 0.0
 
         with torch.no_grad():
             for batch in self.val_loader:
-                low_res = batch['low_ress'].to(self.device)
-                high_res = batch['high_ress'].to(self.device)
+                low_ress = batch['low_ress'].to(self.device)
+                high_ress = batch['high_ress'].to(self.device)
 
-                features = self.model.cnn(low_res)
+                features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
                 outputs = self.model.upsample(features)
 
-                loss = self.criterion(outputs, high_res)
+                loss = self.criterion(outputs, high_ress)
                 val_loss += loss.item()
 
         avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0
@@ -215,11 +217,14 @@ class ModelTrainer:
             self.best_val_loss = avg_val_loss
 
     def __repr__(self):
-        return f"Model ({self.model_name}, batch_size={self.batch_size})"
+        return f"ModelTrainer (model={type(self.model).__name__}, batch_size={self.batch_size})"
 
 
 if __name__ == "__main__":
+    # Load your model first
+    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
+
     trainer = ModelTrainer(
-        model_name="/root/vision/AIIA/AIIA-base-512/",
+        model=model,
         dataset_paths=[
             "/root/training_data/vision-dataset/image_upscaler.parquet",
             "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
@@ -228,4 +233,4 @@ if __name__ == "__main__":
         learning_rate=0.001
     )
 
-    trainer.train(num__epochs=3)
+    trainer.train(num_epochs=3)