new loading

This commit is contained in:
Falko Victor Habel 2025-02-21 21:12:14 +01:00
parent 2583d2f01f
commit fca74fb8d2
1 changed file with 50 additions and 18 deletions

View File

@@ -1,46 +1,78 @@
import pandas as pd
import io
from PIL import Image
from PIL import Image, ImageFile
from torch.utils.data import Dataset
from torchvision import transforms
from aiia import AIIABase
import csv
from tqdm import tqdm
import base64
class UpscaleDataset(Dataset):
    """Dataset of (512px low-res, 1024px high-res) image pairs read from Parquet files.

    Each Parquet file must contain 'image_512' and 'image_1024' columns whose
    cells hold encoded image data, either as raw bytes or as base64 strings.
    Samples that fail to decode are remembered and skipped on later accesses.
    """

    def __init__(self, parquet_files: list, transform=None):
        """Load and concatenate the image columns of every file in *parquet_files*.

        Args:
            parquet_files: paths of the Parquet files to read.
            transform: optional callable applied to both PIL images in ``__getitem__``.

        Raises:
            ValueError: if any cell of the image columns is not ``bytes`` or ``str``.
        """
        combined_df = pd.DataFrame()
        for parquet_file in parquet_files:
            # Read only the two needed columns and cap rows per file to bound memory.
            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(10000)
            combined_df = pd.concat([combined_df, df], ignore_index=True)
        # Validate data format up front so malformed rows fail at load time,
        # not in the middle of training.
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        # Indices that raised while decoding; skipped on subsequent accesses.
        self.failed_indices = set()
        # Set the PIL truncation flag once (it is a process-wide global) instead
        # of on every __getitem__ call, so partially-written files still load.
        ImageFile.LOAD_TRUNCATED_IMAGES = True

    def _validate_row(self, row):
        """Ensure both image cells are bytes or (base64) str; raise ValueError otherwise."""
        for col in ['image_512', 'image_1024']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        """Return raw image bytes from either a base64-encoded str or a bytes object.

        Raises:
            RuntimeError: if *data* has an unsupported type or base64 decoding fails.
        """
        try:
            if isinstance(data, str):
                # Handle base64 encoded strings
                return base64.b64decode(data)
            elif isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            raise RuntimeError(f"Decoding failed: {str(e)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a (low_res, high_res) pair, falling through to the next index on failure.

        Raises:
            RuntimeError: once every index has failed, instead of recursing forever.
        """
        # Guard: without this, a dataset in which every sample fails would
        # recurse through self[(idx + 1) % len(self)] until RecursionError.
        if len(self.failed_indices) >= len(self):
            raise RuntimeError("All samples failed to load")
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]  # Skip failed indices
        try:
            row = self.df.iloc[idx]
            # Decode both images (bytes pass through, base64 strings are decoded).
            low_res_bytes = self._decode_image(row['image_512'])
            high_res_bytes = self._decode_image(row['image_1024'])
            low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
            high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
            # Validate image sizes
            if low_res.size != (512, 512) or high_res.size != (1024, 1024):
                raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)
            return low_res, high_res
        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            return self[(idx + 1) % len(self)]  # Return next valid sample
# Example transform: converting PIL images to tensors
transform = transforms.Compose([