downsized imageset

This commit is contained in:
Falko Victor Habel 2025-02-24 16:20:45 +01:00
parent 443b9f5589
commit 2360e23cc7
1 changed file with 16 additions and 6 deletions

View File

@ -40,9 +40,11 @@ class UpscaleDataset(Dataset):
def __init__(self, parquet_files: list, transform=None):
    """Build the dataset from the first rows of several parquet files.

    Args:
        parquet_files: Paths of the parquet files to read. Each file must
            provide the ``image_512`` and ``image_1024`` columns.
        transform: Optional callable stored as ``self.transform``.
    """
    # Only a subset (the first 2500 rows) of each parquet file is loaded.
    frames = [
        pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(2500)
        for parquet_file in parquet_files
    ]
    # Concatenate once: calling pd.concat inside the loop re-copies every
    # previously loaded row on each iteration. An empty input still yields
    # an empty DataFrame, as before.
    combined_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    # Run the per-row validation hook (self._validate_row, defined elsewhere
    # in the class; presumably checks each value is bytes or str - confirm).
    self.df = combined_df.apply(self._validate_row, axis=1)
    self.transform = transform
    # Starts empty; __getitem__ consults it to skip indices recorded as failed.
    self.failed_indices = set()
@ -67,6 +69,7 @@ class UpscaleDataset(Dataset):
return len(self.df)
def __getitem__(self, idx):
# If previous call failed for this index, use a different index.
if idx in self.failed_indices:
return self[(idx + 1) % len(self)]
try:
@ -74,10 +77,17 @@ class UpscaleDataset(Dataset):
low_res_bytes = self._decode_image(row['image_512'])
high_res_bytes = self._decode_image(row['image_1024'])
ImageFile.LOAD_TRUNCATED_IMAGES = True
low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
if low_res.size != (512, 512) or high_res.size != (1024, 1024):
raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
# Open image bytes with Pillow and convert to RGBA
low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')
# Resize the images to reduce VRAM usage.
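# NOTE: Image.ANTIALIAS is a legacy alias of Image.LANCZOS; it was deprecated in Pillow 9.1 and removed in Pillow 10.0, so these calls require Pillow < 10 (on newer versions use Image.LANCZOS / Image.Resampling.LANCZOS).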
low_res = low_res.resize((384, 384), Image.ANTIALIAS)
high_res = high_res.resize((768, 768), Image.ANTIALIAS)
# If a transform is provided (e.g. conversion to Tensor), apply it.
if self.transform:
low_res = self.transform(low_res)
high_res = self.transform(high_res)
@ -97,7 +107,7 @@ pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
# Load the pretrained base model with bf16 precision and wrap it in an Upsampler.
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
model = Upsampler(base_model)
# NOTE(review): this CPU assignment is overwritten by the very next line and has
# no effect; it looks like the pre-change line of this diff hunk - confirm before removing.
device = torch.device("cpu")
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move model to device using channels_last memory format.
model = model.to(device, memory_format=torch.channels_last)