finetune_class #1
@@ -41,83 +41,10 @@ class ImageDataset(Dataset):
-class TrainingBase:
-    def __init__(self,
-                 model_name: str,
-                 dataset_paths: Union[List[str], Dict[str, str]],
-                 batch_size: int = 32,
-                 learning_rate: float = 0.001,
-                 num_workers: int = 4,
-                 train_ratio: float = 0.8):
-        """
-        Base class for training models with multiple dataset support
-
-        Args:
-            model_name (str): Name of the model to initialize
-            dataset_paths (Union[List[str], Dict[str, str]]): Paths to datasets (train and optional validation)
-            batch_size (int): Batch size for training
-            learning_rate (float): Learning rate for optimizer
-            num_workers (int): Number of workers for data loading
-            train_ratio (float): Ratio of data to use for training (rest goes to validation)
-        """
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-
-        # Initialize datasets and loaders
-        self.dataset_paths = dataset_paths
-        self._initialize_datasets()
-
-        # Initialize model and training parameters
-        self.model_name = model_name
-        self.learning_rate = learning_rate
-        self._initialize_model()
-
-    def _initialize_datasets(self):
-        """Helper method to initialize datasets"""
-        raise NotImplementedError("This method should be implemented in child classes")
-
-    def _initialize_model(self):
-        """Helper method to initialize model architecture"""
-        raise NotImplementedError("This method should be implemented in child classes")
-
-    def train(self, num_epochs: int = 10):
-        """Train the model for specified number of epochs"""
-        self.model.to(self.device)
-
-        for epoch in range(num_epochs):
-            print(f"Epoch {epoch+1}/{num_epochs}")
-
-            # Train phase
-            self._train_epoch()
-
-            # Validation phase
-            self._validate_epoch()
-
-            # Save best model based on validation loss
-            if self.current_val_loss < self.best_val_loss:
-                self.save_model()
-
-    def _train_epoch(self):
-        """Train model for one epoch"""
-        raise NotImplementedError("This method should be implemented in child classes")
-
-    def _validate_epoch(self):
-        """Validate model performance"""
-        raise NotImplementedError("This method should be implemented in child classes")
-
-    def save_model(self):
-        """Save current best model"""
-        torch.save({
-            'model_state_dict': self.model.state_dict(),
-            'optimizer_state_dict': self.optimizer.state_dict(),
-            'best_val_loss': self.best_val_loss
-        }, f"{self.model_name}_best.pth")
-
-class Finetuner(TrainingBase):
+class ModelTrainer:
     def __init__(self,
                  model_name: str = "AIIA-Base-512",
-                 dataset_paths: Union[List[str], Dict[str, str]] = None,
+                 dataset_paths: List[str] = None,
                  batch_size: int = 32,
                  learning_rate: float = 0.001,
                  num_workers: int = 4,
@@ -126,25 +53,42 @@ class Finetuner(TrainingBase):
         Specialized trainer for image super resolution tasks

         Args:
-            Same as TrainingBase
+            model_name (str): Name of the model to initialize
+            dataset_paths (List[str]): Paths to datasets
+            batch_size (int): Batch size for training
+            learning_rate (float): Learning rate for optimizer
+            num_workers (int): Number of workers for data loading
+            train_ratio (float): Ratio of data to use for training (rest goes to validation)
         """
-        super().__init__(model_name, dataset_paths, batch_size, learning_rate, num_workers, train_ratio)
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.dataset_paths = dataset_paths
+        self.model_name = model_name
+        self.learning_rate = learning_rate
+        self.train_ratio = train_ratio
+
+        # Initialize datasets and loaders
+        self._initialize_datasets()
+
+        # Initialize model and training parameters
+        self._initialize_model()

     def _initialize_datasets(self):
-        """Initialize image datasets"""
-        # Load dataframes from parquet files
-        if isinstance(self.dataset_paths, dict):
-            df_train = pd.read_parquet(self.dataset_paths['train'])
-            df_val = pd.read_parquet(self.dataset_paths['val']) if 'val' in self.dataset_paths else None
-        elif isinstance(self.dataset_paths, list):
+        """
+        Helper method to initialize datasets
+        """
+        # Read training data based on input format
+        if isinstance(self.dataset_paths, list):
             df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True)
-            df_val = None
         else:
-            raise ValueError("Invalid dataset_paths format")
+            raise ValueError("Invalid dataset_paths format. Must be a list of parquet file paths.")

-        # Split into train and validation sets if needed
-        if df_val is None:
-            df_train, df_val = train_test_split(df_train, test_size=1 - self.train_ratio, random_state=42)
+        df_train, df_val = train_test_split(
+            df_train,
+            test_size=1 - self.train_ratio,
+            random_state=42
+        )

         # Define preprocessing transforms
         self.transform = transforms.Compose([
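
For orientation between these hunks: the training and validation loops below index batches as batch['low_ress'] and batch['high_ress'], so the ImageDataset defined in the unchanged top of this file must yield dicts with those keys. A minimal sketch of that contract, assuming the parquet rows store an encoded low-res/high-res image pair; the column names and decoding below are illustrative assumptions, not the repository's actual implementation:

    # Sketch of the batch contract that _train_epoch/_validate_epoch rely on.
    import io
    import pandas as pd
    from PIL import Image
    from torch.utils.data import Dataset

    class ImageDatasetSketch(Dataset):
        def __init__(self, df: pd.DataFrame, transform=None):
            self.df = df
            self.transform = transform

        def __len__(self):
            return len(self.df)

        def __getitem__(self, idx):
            row = self.df.iloc[idx]
            # Assumed column names holding encoded image bytes
            low = Image.open(io.BytesIO(row['image_512'])).convert('RGB')
            high = Image.open(io.BytesIO(row['image_1024'])).convert('RGB')
            if self.transform:
                low, high = self.transform(low), self.transform(high)
            # Keys must match what the trainer indexes below
            return {'low_ress': low, 'high_ress': high}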
@@ -168,10 +112,12 @@ class Finetuner(TrainingBase):
             batch_size=self.batch_size,
             shuffle=False,
             num_workers=self.num_workers
-        )
+        ) if df_val is not None else None

     def _initialize_model(self):
-        """Initialize and modify the super resolution model"""
+        """
+        Helper method to initialize model architecture and training parameters
+        """
         # Load base model
         self.model = AIIABase.load(self.model_name)

@@ -181,9 +127,10 @@ class Finetuner(TrainingBase):

         # Add upscaling layer
         hidden_size = self.model.config.hidden_size
+        kernel_size = self.model.config.kernel_size
         self.model.upsample = nn.Sequential(
             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
-            nn.Conv2d(hidden_size, 3, kernel_size=3, padding=1)
+            nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=kernel_size // 2)
         )

         # Initialize optimizer and loss function
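
A quick standalone shape check of the upscaling head added above: bilinear upsampling doubles the spatial size, and a same-padded convolution (padding = kernel_size // 2 for odd kernels) maps the hidden_size feature channels down to 3 RGB channels without changing height or width, which the MSE loss against high_res requires. The concrete values below are assumptions; the real ones come from self.model.config:

    import torch
    import torch.nn as nn

    hidden_size, kernel_size = 512, 3  # assumed config values
    upsample = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
        nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=kernel_size // 2),
    )

    features = torch.randn(1, hidden_size, 64, 64)  # dummy CNN feature map
    print(upsample(features).shape)  # torch.Size([1, 3, 128, 128])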
@@ -195,14 +142,36 @@ class Finetuner(TrainingBase):

         self.best_val_loss = float('inf')

+    def train(self, num_epochs: int = 10):
+        """
+        Train the model for specified number of epochs
+        """
+        self.model.to(self.device)
+
+        for epoch in range(num_epochs):
+            print(f"Epoch {epoch+1}/{num_epochs}")
+
+            # Train phase
+            self._train_epoch()
+
+            # Validation phase
+            if self.val_loader is not None:
+                self._validate_epoch()
+
+                # Save best model based on validation loss
+                if self.current_val_loss < self.best_val_loss:
+                    self.best_val_loss = self.current_val_loss
+                    self.model.save("aiuNN-base")
+
     def _train_epoch(self):
-        """Train model for one epoch"""
+        """
+        Train model for one epoch
+        """
         self.model.train()
         running_loss = 0.0

         for batch in self.train_loader:
-            low_res = batch['low_res'].to(self.device)
-            high_res = batch['high_res'].to(self.device)
+            low_res = batch['low_ress'].to(self.device)
+            high_res = batch['high_ress'].to(self.device)

             # Forward pass
             features = self.model.cnn(low_res)
@@ -221,14 +190,16 @@ class Finetuner(TrainingBase):
         print(f"Train Loss: {epoch_loss:.4f}")

     def _validate_epoch(self):
-        """Validate model performance"""
+        """
+        Validate model performance
+        """
         self.model.eval()
         val_loss = 0.0

         with torch.no_grad():
             for batch in self.val_loader:
-                low_res = batch['low_res'].to(self.device)
-                high_res = batch['high_res'].to(self.device)
+                low_res = batch['low_ress'].to(self.device)
+                high_res = batch['high_ress'].to(self.device)

                 features = self.model.cnn(low_res)
                 outputs = self.model.upsample(features)
@@ -236,24 +207,25 @@ class Finetuner(TrainingBase):
                 loss = self.criterion(outputs, high_res)
                 val_loss += loss.item()

-        avg_val_loss = val_loss / len(self.val_loader)
+        avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0
         print(f"Validation Loss: {avg_val_loss:.4f}")

-        # Update best model
-        if avg_val_loss < self.best_val_loss:
-            self.best_val_loss = avg_val_loss
+        # Record the result; train() compares it against best_val_loss and saves
+        self.current_val_loss = avg_val_loss

     def __repr__(self):
         return f"Model ({self.model_name}, batch_size={self.batch_size})"


-# Example usage:
 if __name__ == "__main__":
-    finetuner = Finetuner(
-        train_parquet_path="/root/training_data/vision-dataset/image_upscaler.parquet",
-        val_parquet_path="/root/training_data/vision-dataset/image_vec_upscaler.parquet",
+    trainer = ModelTrainer(
+        model_name="/root/vision/AIIA/AIIA-base-512/",
+        dataset_paths=[
+            "/root/training_data/vision-dataset/image_upscaler.parquet",
+            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
+        ],
         batch_size=2,
         learning_rate=0.001
     )

-    finetuner.train_model(num_epochs=10)
+    trainer.train(num_epochs=3)
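
Once training has run, the best checkpoint is written via self.model.save("aiuNN-base"). A hedged inference sketch, assuming AIIABase.load accepts the same name that save() was given and that the reloaded model still exposes the cnn and upsample attributes used above (the import path is also an assumption):

    import torch
    from aiia import AIIABase  # assumed import path

    model = AIIABase.load("aiuNN-base")
    model.eval()

    low_res = torch.randn(1, 3, 256, 256)  # placeholder for a preprocessed low-res image
    with torch.no_grad():
        upscaled = model.upsample(model.cnn(low_res))
    print(upscaled.shape)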