develop #4

Merged
Fabel merged 103 commits from develop into main on 2025-03-01 21:47:17 +00:00
1 changed file with 15 additions and 10 deletions
Showing only changes of commit 39efdff09e

@@ -13,13 +13,12 @@ from torch import nn
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
 
 class aiuNNDataset(torch.utils.data.Dataset):
-    def __init__(self, parquet_path, config=None):
+    def __init__(self, parquet_path):
         # Read the Parquet file
         self.df = pd.read_parquet(parquet_path)
 
-        # Data augmentation pipeline
+        # Data augmentation pipeline without Resize as it's redundant
         self.augmentation = Compose([
-            Resize((512, 512)),
             RandomBrightnessContrast(),
             HorizontalFlip(p=0.5),
             VerticalFlip(p=0.5),
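
A note on the pipeline above: later in this commit it is applied to the low-res and the high-res image in two separate calls, so each call samples its own random flips and a pair can fall out of alignment. A minimal sketch of a synchronized variant using albumentations' additional_targets (an illustration with hypothetical names, not code from this PR):

import albumentations as A
import numpy as np

# Sketch, not the PR's code: additional_targets replays the same sampled
# transform parameters on every registered image target, so the low-res
# input and the high-res target stay pixel-aligned under random flips.
paired_augmentation = A.Compose(
    [
        A.RandomBrightnessContrast(),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ],
    additional_targets={'high_res': 'image'},
)

low_res_image = np.zeros((512, 512, 3), dtype=np.uint8)     # stand-in data
high_res_image = np.zeros((1024, 1024, 3), dtype=np.uint8)  # stand-in data
augmented = paired_augmentation(image=low_res_image, high_res=high_res_image)
low_res, high_res = augmented['image'], augmented['high_res']
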
@@ -81,19 +80,25 @@ class aiuNNDataset(torch.utils.data.Dataset):
         high_res_image = self.load_image(row['image_1024'])
 
         # Apply augmentation and normalization
-        augmented = self.augmentation(image=low_res_image, mask=high_res_image)
-        low_res = augmented['image']
-        high_res = augmented['mask']
+        augmented_low = self.augmentation(image=low_res_image)
+        low_res = augmented_low['image']
+
+        augmented_high = self.augmentation(image=high_res_image)
+        high_res = augmented_high['image']
 
         return {
             'low_res': low_res,
             'high_res': high_res
         }
 
-def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=2, epochs = 10):
-    # Initialize dataset and dataloader
-    train_dataset = aiuNNDataset(train_parquet_path)
-    val_dataset = aiuNNDataset(val_parquet_path)
+from torch.utils.data.dataset import ConcatDataset
+
+def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10):
+    # Load all datasets and concatenate them
+    loaded_datasets = [aiuNNDataset(d) for d in datasets]
+    combined_dataset = ConcatDataset(loaded_datasets)
+    # Split into training and validation sets
+    train_dataset, val_dataset = combined_dataset.train_val_split()
 
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
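
Stock torch.utils.data.ConcatDataset does not define a train_val_split method, so the call above presumably relies on a project-local helper. A minimal sketch of what such a helper could look like, built on torch.utils.data.random_split (the function name, ratio, and seed are assumptions, not code from this repo):

import torch
from torch.utils.data import random_split

def train_val_split(dataset, val_ratio=0.2, seed=42):
    # Hypothetical helper: splits any Dataset (including a ConcatDataset)
    # into reproducible train/validation subsets.
    val_size = int(len(dataset) * val_ratio)
    train_size = len(dataset) - val_size
    return random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(seed),
    )

# Usage mirroring the diff:
# train_dataset, val_dataset = train_val_split(combined_dataset)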