develop #4

@@ -13,13 +13,12 @@ from torch import nn
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive

 class aiuNNDataset(torch.utils.data.Dataset):
-    def __init__(self, parquet_path, config=None):
+    def __init__(self, parquet_path):
         # Read the Parquet file
         self.df = pd.read_parquet(parquet_path)

-        # Data augmentation pipeline
+        # Data augmentation pipeline without Resize as it's redundant
         self.augmentation = Compose([
-            Resize((512, 512)),
             RandomBrightnessContrast(),
             HorizontalFlip(p=0.5),
             VerticalFlip(p=0.5),
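Note on the hunk above: dropping Resize((512, 512)) leaves only photometric and flip transforms in the Compose pipeline, on the assumption that the stored images already have their target resolutions. Below is a minimal sketch of how such an albumentations pipeline is built and applied; the Normalize/ToTensorV2 tail and the exact imports are assumptions based on the "augmentation and normalization" comment later in the file, not part of this hunk.

# Sketch only: assumes albumentations; Normalize/ToTensorV2 are guesses, not shown in this diff.
import numpy as np
from albumentations import Compose, RandomBrightnessContrast, HorizontalFlip, VerticalFlip, Normalize
from albumentations.pytorch import ToTensorV2

augmentation = Compose([
    RandomBrightnessContrast(),
    HorizontalFlip(p=0.5),
    VerticalFlip(p=0.5),
    Normalize(),    # assumed; defaults to ImageNet mean/std scaling
    ToTensorV2(),   # assumed; converts the HWC array to a CHW torch tensor
])

image = np.zeros((512, 512, 3), dtype=np.uint8)   # placeholder input
augmented = augmentation(image=image)
tensor = augmented['image']                       # torch.Tensor of shape (3, 512, 512)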
@@ -81,19 +80,25 @@ class aiuNNDataset(torch.utils.data.Dataset):
         high_res_image = self.load_image(row['image_1024'])

         # Apply augmentation and normalization
-        augmented = self.augmentation(image=low_res_image, mask=high_res_image)
-        low_res = augmented['image']
-        high_res = augmented['mask']
+        augmented_low = self.augmentation(image=low_res_image)
+        low_res = augmented_low['image']
+
+        augmented_high = self.augmentation(image=high_res_image)
+        high_res = augmented_high['image']

         return {
             'low_res': low_res,
             'high_res': high_res
         }

+from torch.utils.data.dataset import ConcatDataset

-def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=2, epochs = 10):
-    # Initialize dataset and dataloader
-    train_dataset = aiuNNDataset(train_parquet_path)
-    val_dataset = aiuNNDataset(val_parquet_path)
+def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10):
+    # Load all datasets and concatenate them
+    loaded_datasets = [aiuNNDataset(d) for d in datasets]
+    combined_dataset = ConcatDataset(loaded_datasets)
+
+    # Split into training and validation sets
+    train_dataset, val_dataset = combined_dataset.train_val_split()

     train_loader = torch.utils.data.DataLoader(
         train_dataset,
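On the __getitem__ change: calling self.augmentation twice samples the random flips independently for the low-res and high-res image, whereas the previous image/mask call applied one sampled set of parameters to both. If aligned pairs are wanted, albumentations' additional_targets keeps the two images synchronized; the sketch below shows that alternative under illustrative names, it is not the code in this PR.

# Sketch only: aligned augmentation via additional_targets; placeholder inputs, not the PR's code.
import numpy as np
from albumentations import Compose, RandomBrightnessContrast, HorizontalFlip, VerticalFlip

paired_augmentation = Compose(
    [
        RandomBrightnessContrast(),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
    ],
    additional_targets={'high_res': 'image'},  # treat 'high_res' as a second image target
)

low_res_image = np.zeros((512, 512, 3), dtype=np.uint8)     # placeholder
high_res_image = np.zeros((1024, 1024, 3), dtype=np.uint8)  # placeholder

augmented = paired_augmentation(image=low_res_image, high_res=high_res_image)
low_res = augmented['image']        # same flip decision applied to both outputs
high_res = augmented['high_res']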
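On the finetune_model change: ConcatDataset in stock PyTorch has no train_val_split() method, so that call presumably relies on a helper defined elsewhere in the project. If no such helper exists, torch.utils.data.random_split achieves the same split; a sketch assuming an 80/20 ratio (the ratio is an assumption, not stated in the diff) and reusing the names from the hunk above:

# Sketch only: stock-PyTorch equivalent of the split, assuming an 80/20 ratio.
from torch.utils.data import ConcatDataset, random_split

loaded_datasets = [aiuNNDataset(d) for d in datasets]   # as in the diff above
combined_dataset = ConcatDataset(loaded_datasets)

train_len = int(0.8 * len(combined_dataset))            # assumed ratio
val_len = len(combined_dataset) - train_len
train_dataset, val_dataset = random_split(combined_dataset, [train_len, val_len])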