diff --git a/example.py b/example.py index d03e931..58470b0 100644 --- a/example.py +++ b/example.py @@ -11,7 +11,7 @@ from torchvision import transforms class UpscaleDataset(Dataset): - def __init__(self, parquet_files: list, transform=None, samples_per_file=5000): + def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000): combined_df = pd.DataFrame() for parquet_file in parquet_files: # Load a subset from each parquet file @@ -68,8 +68,8 @@ class UpscaleDataset(Dataset): high_res = high_res_rgb # Resize the images to reduce VRAM usage - low_res = low_res.resize((384, 384), Image.LANCZOS) - high_res = high_res.resize((768, 768), Image.LANCZOS) + low_res = low_res.resize((410, 410), Image.LANCZOS) + high_res = high_res.resize((820, 820), Image.LANCZOS) # If a transform is provided (e.g. conversion to Tensor), apply it if self.transform: