From 39efdff09ed25eb76c573eda0faebcca15972d81 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Thu, 6 Feb 2025 20:54:05 +0100
Subject: [PATCH] updated datasets

---
 src/aiunn/finetune.py | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 23434cd..6ce67b4 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -13,13 +13,12 @@ from torch import nn
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
 
 class aiuNNDataset(torch.utils.data.Dataset):
-    def __init__(self, parquet_path, config=None):
+    def __init__(self, parquet_path):
         # Read the Parquet file
         self.df = pd.read_parquet(parquet_path)
 
-        # Data augmentation pipeline
+        # Data augmentation pipeline without Resize as it's redundant
         self.augmentation = Compose([
-            Resize((512, 512)),
             RandomBrightnessContrast(),
             HorizontalFlip(p=0.5),
             VerticalFlip(p=0.5),
@@ -81,19 +80,25 @@ class aiuNNDataset(torch.utils.data.Dataset):
         high_res_image = self.load_image(row['image_1024'])
 
         # Apply augmentation and normalization
-        augmented = self.augmentation(image=low_res_image, mask=high_res_image)
-        low_res = augmented['image']
-        high_res = augmented['mask']
+        augmented_low = self.augmentation(image=low_res_image)
+        low_res = augmented_low['image']
+
+        augmented_high = self.augmentation(image=high_res_image)
+        high_res = augmented_high['image']
 
         return {
             'low_res': low_res,
             'high_res': high_res
         }
 
+from torch.utils.data.dataset import ConcatDataset
 
-def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=2, epochs = 10):
-    # Initialize dataset and dataloader
-    train_dataset = aiuNNDataset(train_parquet_path)
-    val_dataset = aiuNNDataset(val_parquet_path)
+def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10):
+    # Load all datasets and concatenate them
+    loaded_datasets = [aiuNNDataset(d) for d in datasets]
+    combined_dataset = ConcatDataset(loaded_datasets)
+
+    # Split into training and validation sets
+    train_dataset, val_dataset = combined_dataset.train_val_split()
 
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
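
A caveat on the augmentation change above: calling self.augmentation separately on the low-res and high-res images lets the random flips fire independently, so a pair can end up with the low-res input flipped while its high-res target is not. One way to keep the pair geometrically aligned is albumentations' additional_targets mechanism, sketched below; the target name 'high_res' is an illustrative choice, not something used by this patch.

# Minimal sketch, assuming the albumentations pipeline used above;
# 'high_res' is a hypothetical target name chosen for illustration.
from albumentations import Compose, HorizontalFlip, VerticalFlip, RandomBrightnessContrast

paired_augmentation = Compose(
    [
        RandomBrightnessContrast(),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
    ],
    # Apply the same sampled transform parameters to both images of a pair.
    additional_targets={'high_res': 'image'},
)

augmented = paired_augmentation(image=low_res_image, high_res=high_res_image)
low_res, high_res = augmented['image'], augmented['high_res']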
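
Separately, torch.utils.data.ConcatDataset does not define a train_val_split() method, so the call at the end of the new finetune_model would raise AttributeError at runtime. A minimal sketch of one way to get the intended split with torch.utils.data.random_split follows; the 80/20 ratio and the fixed seed are assumptions for illustration, not values taken from this patch.

# Minimal sketch, assuming an 80/20 split and a fixed seed for
# reproducibility; both values are illustrative, not from this patch.
import torch
from torch.utils.data import ConcatDataset, random_split

def split_combined_dataset(loaded_datasets, val_fraction=0.2, seed=42):
    combined = ConcatDataset(loaded_datasets)
    val_len = int(len(combined) * val_fraction)
    train_len = len(combined) - val_len
    generator = torch.Generator().manual_seed(seed)
    # random_split returns Subset views over the combined dataset.
    return random_split(combined, [train_len, val_len], generator=generator)

As a minor note, ConcatDataset is importable from torch.utils.data directly, which is the documented public path; the patch's torch.utils.data.dataset import works but relies on an internal module location.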