finetune_class #1
|
@ -13,13 +13,12 @@ from torch import nn
|
|||
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
|
||||
|
||||
class aiuNNDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, parquet_path, config=None):
|
||||
def __init__(self, parquet_path):
|
||||
# Read the Parquet file
|
||||
self.df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Data augmentation pipeline
|
||||
# Data augmentation pipeline without Resize as it's redundant
|
||||
self.augmentation = Compose([
|
||||
Resize((512, 512)),
|
||||
RandomBrightnessContrast(),
|
||||
HorizontalFlip(p=0.5),
|
||||
VerticalFlip(p=0.5),
|
||||
|
@ -81,19 +80,25 @@ class aiuNNDataset(torch.utils.data.Dataset):
|
|||
high_res_image = self.load_image(row['image_1024'])
|
||||
|
||||
# Apply augmentation and normalization
|
||||
augmented = self.augmentation(image=low_res_image, mask=high_res_image)
|
||||
low_res = augmented['image']
|
||||
high_res = augmented['mask']
|
||||
augmented_low = self.augmentation(image=low_res_image)
|
||||
low_res = augmented_low['image']
|
||||
|
||||
augmented_high = self.augmentation(image=high_res_image)
|
||||
high_res = augmented_high['image']
|
||||
|
||||
return {
|
||||
'low_res': low_res,
|
||||
'high_res': high_res
|
||||
}
|
||||
from torch.utils.data.dataset import ConcatDataset
|
||||
|
||||
def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=2, epochs = 10):
|
||||
# Initialize dataset and dataloader
|
||||
train_dataset = aiuNNDataset(train_parquet_path)
|
||||
val_dataset = aiuNNDataset(val_parquet_path)
|
||||
def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10):
|
||||
# Load all datasets and concatenate them
|
||||
loaded_datasets = [aiuNNDataset(d) for d in datasets]
|
||||
combined_dataset = ConcatDataset(loaded_datasets)
|
||||
|
||||
# Split into training and validation sets
|
||||
train_dataset, val_dataset = combined_dataset.train_val_split()
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
|
|
Loading…
Reference in New Issue