finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 15 additions and 10 deletions
Showing only changes of commit 39efdff09e - Show all commits

View File

@ -13,13 +13,12 @@ from torch import nn
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
class aiuNNDataset(torch.utils.data.Dataset):
def __init__(self, parquet_path, config=None):
def __init__(self, parquet_path):
# Read the Parquet file
self.df = pd.read_parquet(parquet_path)
# Data augmentation pipeline
# Data augmentation pipeline without Resize as it's redundant
self.augmentation = Compose([
Resize((512, 512)),
RandomBrightnessContrast(),
HorizontalFlip(p=0.5),
VerticalFlip(p=0.5),
@ -81,19 +80,25 @@ class aiuNNDataset(torch.utils.data.Dataset):
high_res_image = self.load_image(row['image_1024'])
# Apply augmentation and normalization
augmented = self.augmentation(image=low_res_image, mask=high_res_image)
low_res = augmented['image']
high_res = augmented['mask']
augmented_low = self.augmentation(image=low_res_image)
low_res = augmented_low['image']
augmented_high = self.augmentation(image=high_res_image)
high_res = augmented_high['image']
return {
'low_res': low_res,
'high_res': high_res
}
from torch.utils.data.dataset import ConcatDataset
def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=2, epochs = 10):
# Initialize dataset and dataloader
train_dataset = aiuNNDataset(train_parquet_path)
val_dataset = aiuNNDataset(val_parquet_path)
def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10):
# Load all datasets and concatenate them
loaded_datasets = [aiuNNDataset(d) for d in datasets]
combined_dataset = ConcatDataset(loaded_datasets)
# Split into training and validation sets
train_dataset, val_dataset = combined_dataset.train_val_split()
train_loader = torch.utils.data.DataLoader(
train_dataset,