diff --git a/example.py b/example.py
index b2c26c8..58470b0 100644
--- a/example.py
+++ b/example.py
@@ -11,7 +11,7 @@ from torchvision import transforms
 
 
 class UpscaleDataset(Dataset):
-    def __init__(self, parquet_files: list, transform=None, samples_per_file=2500):
+    def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000):
         combined_df = pd.DataFrame()
         for parquet_file in parquet_files:
             # Load a subset from each parquet file
@@ -68,8 +68,8 @@ class UpscaleDataset(Dataset):
         high_res = high_res_rgb
 
         # Resize the images to reduce VRAM usage
-        low_res = low_res.resize((384, 384), Image.LANCZOS)
-        high_res = high_res.resize((768, 768), Image.LANCZOS)
+        low_res = low_res.resize((410, 410), Image.LANCZOS)
+        high_res = high_res.resize((820, 820), Image.LANCZOS)
 
         # If a transform is provided (e.g. conversion to Tensor), apply it
         if self.transform:
diff --git a/src/aiunn/finetune/trainer.py b/src/aiunn/finetune/trainer.py
index 01047b9..b94d57d 100644
--- a/src/aiunn/finetune/trainer.py
+++ b/src/aiunn/finetune/trainer.py
@@ -260,7 +260,8 @@ class aiuNNTrainer:
                 torch.cuda.empty_cache()
 
             # Check early stopping
-            if early_stopping(val_loss if self.validation_loader else avg_train_loss):
+            early_stopping(val_loss if self.validation_loader else avg_train_loss)
+            if early_stopping.early_stop:
                 print(f"Early stopping triggered at epoch {epoch + 1}")
                 break
 
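
The trainer hunk assumes that the early-stopping helper no longer returns a boolean from its call; instead, calling it records the latest loss and a separate `early_stop` attribute flags when training should halt. A minimal sketch of a helper with that interface is shown below; the names `patience` and `min_delta` and their defaults are assumptions for illustration, not the project's actual implementation.

```python
class EarlyStopping:
    """Hypothetical sketch: tracks the monitored loss and sets `early_stop`
    once it has not improved for `patience` consecutive checks."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience      # checks to wait without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best_loss = float("inf")
        self.counter = 0
        self.early_stop = False

    def __call__(self, loss: float) -> None:
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss     # improvement: reset the counter
            self.counter = 0
        else:
            self.counter += 1         # no improvement this check
            if self.counter >= self.patience:
                self.early_stop = True
```

Used as in the patched trainer loop: call `early_stopping(val_loss)` once per epoch, then read `early_stopping.early_stop` to decide whether to break.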