improved early stoppping
This commit is contained in:
parent
2e1556d306
commit
7fd2af6f12
|
@ -11,7 +11,7 @@ from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
class UpscaleDataset(Dataset):
|
class UpscaleDataset(Dataset):
|
||||||
def __init__(self, parquet_files: list, transform=None, samples_per_file=2500):
|
def __init__(self, parquet_files: list, transform=None, samples_per_file=5000):
|
||||||
combined_df = pd.DataFrame()
|
combined_df = pd.DataFrame()
|
||||||
for parquet_file in parquet_files:
|
for parquet_file in parquet_files:
|
||||||
# Load a subset from each parquet file
|
# Load a subset from each parquet file
|
||||||
|
|
|
@ -260,7 +260,8 @@ class aiuNNTrainer:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# Check early stopping
|
# Check early stopping
|
||||||
if early_stopping(val_loss if self.validation_loader else avg_train_loss):
|
early_stopping(val_loss if self.validation_loader else avg_train_loss)
|
||||||
|
if early_stopping.early_stop:
|
||||||
print(f"Early stopping triggered at epoch {epoch + 1}")
|
print(f"Early stopping triggered at epoch {epoch + 1}")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue