Commit fca74fb8d2 ("new loading"), parent 2583d2f01f.
Diff hunk: @@ -1,46 +1,78 @@
import pandas as pd
import io
from PIL import Image, ImageFile
from torch.utils.data import Dataset
from torchvision import transforms
from aiia import AIIABase
import csv
from tqdm import tqdm
import base64

||||
class UpscaleDataset(Dataset):
    """Paired low-/high-resolution image dataset backed by Parquet files.

    Each row must provide an ``image_512`` (512x512 low-res) and an
    ``image_1024`` (1024x1024 high-res) image, stored either as raw
    bytes or as a base64-encoded string.
    """

    def __init__(self, parquet_files: list, transform=None):
        """Load and combine image rows from *parquet_files*.

        Args:
            parquet_files: Paths of Parquet files to read. Only the
                ``image_512`` and ``image_1024`` columns are loaded, and
                at most the first 10000 rows of each file (memory cap).
            transform: Optional callable applied to both PIL images in
                ``__getitem__`` (e.g. a torchvision transform).
        """
        # Collect per-file frames and concatenate once at the end;
        # pd.concat inside the loop would re-copy the accumulated data
        # on every iteration (quadratic in total size).
        frames = [
            pd.read_parquet(pf, columns=['image_512', 'image_1024']).head(10000)
            for pf in parquet_files
        ]
        combined_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

        # Fail fast on rows whose cells are neither bytes nor str.
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        # Indices that raised during decoding; skipped on later lookups.
        self.failed_indices = set()

    def _validate_row(self, row):
        """Return *row* unchanged; raise ValueError if either image cell
        is not ``bytes`` or ``str``."""
        for col in ['image_512', 'image_1024']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        """Return raw image bytes from *data*.

        Accepts either a base64-encoded ``str`` or raw ``bytes``.

        Raises:
            RuntimeError: if decoding fails or the type is unsupported
                (the original ValueError is wrapped by the broad except).
        """
        try:
            if isinstance(data, str):
                # Handle base64 encoded strings
                return base64.b64decode(data)
            elif isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            raise RuntimeError(f"Decoding failed: {str(e)}")

    def __len__(self):
        """Number of rows in the combined DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return a ``(low_res, high_res)`` pair for *idx*.

        On a decoding/validation error the index is recorded in
        ``failed_indices`` and the next index (modulo length) is
        returned instead, so training can continue past bad rows.
        """
        # Skip indices that already failed instead of re-decoding them.
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]

        try:
            row = self.df.iloc[idx]

            # Decode both images (base64 str or raw bytes -> bytes).
            low_res_bytes = self._decode_image(row['image_512'])
            high_res_bytes = self._decode_image(row['image_1024'])

            # Tolerate truncated files instead of raising mid-decode.
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
            high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')

            # Enforce the expected 512 -> 1024 upscaling pair.
            if low_res.size != (512, 512) or high_res.size != (1024, 1024):
                raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")

            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)

            return low_res, high_res

        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            # Guard against infinite recursion when every row is bad.
            if len(self.failed_indices) >= len(self.df):
                raise RuntimeError("All samples failed to load") from e
            return self[(idx + 1) % len(self)]  # Return next valid sample
|
||||
|
||||
# Example transform: converting PIL images to tensors
|
||||
transform = transforms.Compose([
|
||||
|
|
Loading…
Reference in New Issue