Merge pull request 'overall_performance_improvement' (#3) from overall_performance_improvement into develop
Reviewed-on: #3
This commit is contained in:
commit
81eaebac5b
|
@ -11,7 +11,7 @@ from torchvision import transforms
|
|||
|
||||
|
||||
class UpscaleDataset(Dataset):
|
||||
def __init__(self, parquet_files: list, transform=None, samples_per_file=2500):
|
||||
def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000):
|
||||
combined_df = pd.DataFrame()
|
||||
for parquet_file in parquet_files:
|
||||
# Load a subset from each parquet file
|
||||
|
@ -68,8 +68,8 @@ class UpscaleDataset(Dataset):
|
|||
high_res = high_res_rgb
|
||||
|
||||
# Resize the images to reduce VRAM usage
|
||||
low_res = low_res.resize((384, 384), Image.LANCZOS)
|
||||
high_res = high_res.resize((768, 768), Image.LANCZOS)
|
||||
low_res = low_res.resize((410, 410), Image.LANCZOS)
|
||||
high_res = high_res.resize((820, 820), Image.LANCZOS)
|
||||
|
||||
# If a transform is provided (e.g. conversion to Tensor), apply it
|
||||
if self.transform:
|
||||
|
|
|
@ -260,7 +260,8 @@ class aiuNNTrainer:
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
# Check early stopping
|
||||
if early_stopping(val_loss if self.validation_loader else avg_train_loss):
|
||||
early_stopping(val_loss if self.validation_loader else avg_train_loss)
|
||||
if early_stopping.early_stop:
|
||||
print(f"Early stopping triggered at epoch {epoch + 1}")
|
||||
break
|
||||
|
||||
|
|
Loading…
Reference in New Issue