From 7fd2af6f1242c53af86ca8a61cbc5a1ea43382a3 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Tue, 25 Feb 2025 19:47:18 +0100 Subject: [PATCH] improved early stoppping --- example.py | 2 +- src/aiunn/finetune/trainer.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/example.py b/example.py index b2c26c8..d03e931 100644 --- a/example.py +++ b/example.py @@ -11,7 +11,7 @@ from torchvision import transforms class UpscaleDataset(Dataset): - def __init__(self, parquet_files: list, transform=None, samples_per_file=2500): + def __init__(self, parquet_files: list, transform=None, samples_per_file=5000): combined_df = pd.DataFrame() for parquet_file in parquet_files: # Load a subset from each parquet file diff --git a/src/aiunn/finetune/trainer.py b/src/aiunn/finetune/trainer.py index 01047b9..b94d57d 100644 --- a/src/aiunn/finetune/trainer.py +++ b/src/aiunn/finetune/trainer.py @@ -260,7 +260,8 @@ class aiuNNTrainer: torch.cuda.empty_cache() # Check early stopping - if early_stopping(val_loss if self.validation_loader else avg_train_loss): + early_stopping(val_loss if self.validation_loader else avg_train_loss) + if early_stopping.early_stop: print(f"Early stopping triggered at epoch {epoch + 1}") break