develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed files with 50 additions and 18 deletions
Showing only changes of commit fca74fb8d2 - Show all commits

View File

@ -1,46 +1,78 @@
import pandas as pd import pandas as pd
import io import io
from PIL import Image from PIL import Image, ImageFile
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchvision import transforms from torchvision import transforms
from aiia import AIIABase from aiia import AIIABase
import csv import csv
from tqdm import tqdm from tqdm import tqdm
import base64
class UpscaleDataset(Dataset): class UpscaleDataset(Dataset):
def __init__(self, parquet_files: list, transform=None): def __init__(self, parquet_files: list, transform=None):
# Initialize an empty DataFrame to hold the combined data
combined_df = pd.DataFrame() combined_df = pd.DataFrame()
# Iterate through each Parquet file in the list and load it into a DataFrame
for parquet_file in parquet_files: for parquet_file in parquet_files:
# Load data with chunking for memory efficiency
df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(10000) df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(10000)
combined_df = pd.concat([combined_df, df], ignore_index=True) combined_df = pd.concat([combined_df, df], ignore_index=True)
self.df = combined_df
# Validate data format
self.df = combined_df.apply(self._validate_row, axis=1)
self.transform = transform self.transform = transform
self.failed_indices = set()
def _validate_row(self, row):
"""Ensure both images exist and have correct dimensions"""
for col in ['image_512', 'image_1024']:
if not isinstance(row[col], (bytes, str)):
raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
return row
def _decode_image(self, data):
"""Universal decoder handling both base64 strings and bytes"""
try:
if isinstance(data, str):
# Handle base64 encoded strings
return base64.b64decode(data)
elif isinstance(data, bytes):
return data
raise ValueError(f"Unsupported data type: {type(data)}")
except Exception as e:
raise RuntimeError(f"Decoding failed: {str(e)}")
def __len__(self): def __len__(self):
return len(self.df) return len(self.df)
def __getitem__(self, idx): def __getitem__(self, idx):
if idx in self.failed_indices:
return self[(idx + 1) % len(self)] # Skip failed indices
try: try:
row = self.df.iloc[idx] row = self.df.iloc[idx]
# Convert string to bytes if necessary
low_res_bytes = row['image_512'].encode('latin-1') if isinstance(row['image_512'], str) else row['image_512']
high_res_bytes = row['image_1024'].encode('latin-1') if isinstance(row['image_1024'], str) else row['image_1024']
# Decode the bytes into images
low_res_image = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
high_res_image = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
# Decode both images
low_res_bytes = self._decode_image(row['image_512'])
high_res_bytes = self._decode_image(row['image_1024'])
# Load images with truncation handling
ImageFile.LOAD_TRUNCATED_IMAGES = True
low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
# Validate image sizes
if low_res.size != (512, 512) or high_res.size != (1024, 1024):
raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
if self.transform: if self.transform:
low_res_image = self.transform(low_res_image) low_res = self.transform(low_res)
high_res_image = self.transform(high_res_image) high_res = self.transform(high_res)
return low_res_image, high_res_image
return low_res, high_res
except Exception as e: except Exception as e:
print(f"Error processing index {idx}: {str(e)}") print(f"\nError at index {idx}: {str(e)}")
# You might want to either skip this sample or return a default value self.failed_indices.add(idx)
raise e return self[(idx + 1) % len(self)] # Return next valid sample
# Example transform: converting PIL images to tensors # Example transform: converting PIL images to tensors
transform = transforms.Compose([ transform = transforms.Compose([