diff --git a/example.py b/example.py index b2c26c8..d03e931 100644 --- a/example.py +++ b/example.py @@ -11,7 +11,7 @@ from torchvision import transforms class UpscaleDataset(Dataset): - def __init__(self, parquet_files: list, transform=None, samples_per_file=2500): + def __init__(self, parquet_files: list, transform=None, samples_per_file=5000): combined_df = pd.DataFrame() for parquet_file in parquet_files: # Load a subset from each parquet file diff --git a/src/aiunn/finetune/trainer.py b/src/aiunn/finetune/trainer.py index 01047b9..b94d57d 100644 --- a/src/aiunn/finetune/trainer.py +++ b/src/aiunn/finetune/trainer.py @@ -260,7 +260,8 @@ class aiuNNTrainer: torch.cuda.empty_cache() # Check early stopping - if early_stopping(val_loss if self.validation_loader else avg_train_loss): + early_stopping(val_loss if self.validation_loader else avg_train_loss) + if early_stopping.early_stop: print(f"Early stopping triggered at epoch {epoch + 1}") break