diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index e616bd0..b0dd634 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -18,7 +18,7 @@ from torch.utils.checkpoint import checkpoint class aiuNNDataset(torch.utils.data.Dataset): def __init__(self, parquet_path): - self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2000) + self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(10000) self.augmentation = Compose([ RandomBrightnessContrast(p=0.5), HorizontalFlip(p=0.5),