improved quality
This commit is contained in:
parent
7fd2af6f12
commit
ed80b0b068
|
@ -11,7 +11,7 @@ from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
class UpscaleDataset(Dataset):
|
class UpscaleDataset(Dataset):
|
||||||
def __init__(self, parquet_files: list, transform=None, samples_per_file=5000):
|
def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000):
|
||||||
combined_df = pd.DataFrame()
|
combined_df = pd.DataFrame()
|
||||||
for parquet_file in parquet_files:
|
for parquet_file in parquet_files:
|
||||||
# Load a subset from each parquet file
|
# Load a subset from each parquet file
|
||||||
|
@ -68,8 +68,8 @@ class UpscaleDataset(Dataset):
|
||||||
high_res = high_res_rgb
|
high_res = high_res_rgb
|
||||||
|
|
||||||
# Resize the images to reduce VRAM usage
|
# Resize the images to reduce VRAM usage
|
||||||
low_res = low_res.resize((384, 384), Image.LANCZOS)
|
low_res = low_res.resize((410, 410), Image.LANCZOS)
|
||||||
high_res = high_res.resize((768, 768), Image.LANCZOS)
|
high_res = high_res.resize((820, 820), Image.LANCZOS)
|
||||||
|
|
||||||
# If a transform is provided (e.g. conversion to Tensor), apply it
|
# If a transform is provided (e.g. conversion to Tensor), apply it
|
||||||
if self.transform:
|
if self.transform:
|
||||||
|
|
Loading…
Reference in New Issue