From 7fd2af6f1242c53af86ca8a61cbc5a1ea43382a3 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Tue, 25 Feb 2025 19:47:18 +0100 Subject: [PATCH 1/2] improved early stoppping --- example.py | 2 +- src/aiunn/finetune/trainer.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/example.py b/example.py index b2c26c8..d03e931 100644 --- a/example.py +++ b/example.py @@ -11,7 +11,7 @@ from torchvision import transforms class UpscaleDataset(Dataset): - def __init__(self, parquet_files: list, transform=None, samples_per_file=2500): + def __init__(self, parquet_files: list, transform=None, samples_per_file=5000): combined_df = pd.DataFrame() for parquet_file in parquet_files: # Load a subset from each parquet file diff --git a/src/aiunn/finetune/trainer.py b/src/aiunn/finetune/trainer.py index 01047b9..b94d57d 100644 --- a/src/aiunn/finetune/trainer.py +++ b/src/aiunn/finetune/trainer.py @@ -260,7 +260,8 @@ class aiuNNTrainer: torch.cuda.empty_cache() # Check early stopping - if early_stopping(val_loss if self.validation_loader else avg_train_loss): + early_stopping(val_loss if self.validation_loader else avg_train_loss) + if early_stopping.early_stop: print(f"Early stopping triggered at epoch {epoch + 1}") break From ed80b0b06852e3450fe3d44c2afe57a9cabd66c5 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Wed, 26 Feb 2025 12:33:25 +0100 Subject: [PATCH 2/2] improved quality --- example.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example.py b/example.py index d03e931..58470b0 100644 --- a/example.py +++ b/example.py @@ -11,7 +11,7 @@ from torchvision import transforms class UpscaleDataset(Dataset): - def __init__(self, parquet_files: list, transform=None, samples_per_file=5000): + def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000): combined_df = pd.DataFrame() for parquet_file in parquet_files: # Load a subset from each parquet file @@ -68,8 +68,8 @@ class UpscaleDataset(Dataset): high_res = high_res_rgb # Resize the images to reduce VRAM usage - low_res = low_res.resize((384, 384), Image.LANCZOS) - high_res = high_res.resize((768, 768), Image.LANCZOS) + low_res = low_res.resize((410, 410), Image.LANCZOS) + high_res = high_res.resize((820, 820), Image.LANCZOS) # If a transform is provided (e.g. conversion to Tensor), apply it if self.transform: