diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 80dad0b..3e0441f 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -8,24 +8,41 @@ import csv
 from tqdm import tqdm
 
 class UpscaleDataset(Dataset):
-    def __init__(self, parquet_file, transform=None):
-        self.df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(10000)
+    def __init__(self, parquet_files: list, transform=None):
+        # Initialize an empty DataFrame to hold the combined data
+        combined_df = pd.DataFrame()
+
+        # Iterate through each Parquet file in the list and load it into a DataFrame
+        for parquet_file in parquet_files:
+            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(10000)
+            combined_df = pd.concat([combined_df, df], ignore_index=True)
+
+        # Keep the combined frame; __len__/__getitem__ read self.df
+        self.df = combined_df
         self.transform = transform
 
     def __len__(self):
         return len(self.df)
 
     def __getitem__(self, idx):
-        row = self.df.iloc[idx]
-        # Decode the byte strings into images
-        low_res_bytes = row['image_512']
-        high_res_bytes = row['image_1024']
-        low_res_image = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
-        high_res_image = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
-        if self.transform:
-            low_res_image = self.transform(low_res_image)
-            high_res_image = self.transform(high_res_image)
-        return low_res_image, high_res_image
+        try:
+            row = self.df.iloc[idx]
+            # Convert string to bytes if necessary
+            low_res_bytes = row['image_512'].encode('latin-1') if isinstance(row['image_512'], str) else row['image_512']
+            high_res_bytes = row['image_1024'].encode('latin-1') if isinstance(row['image_1024'], str) else row['image_1024']
+
+            # Decode the bytes into images
+            low_res_image = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
+            high_res_image = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
+
+            if self.transform:
+                low_res_image = self.transform(low_res_image)
+                high_res_image = self.transform(high_res_image)
+            return low_res_image, high_res_image
+        except Exception as e:
+            print(f"Error processing index {idx}: {str(e)}")
+            # You might want to either skip this sample or return a default value
+            raise e
 
 # Example transform: converting PIL images to tensors
 transform = transforms.Compose([
@@ -46,7 +63,7 @@ from torch import nn, optim
 from torch.utils.data import DataLoader
 
 # Create your dataset and dataloader
-dataset = UpscaleDataset("/root/training_data/vision-dataset/image_upscaler.parquet", transform=transform)
+dataset = UpscaleDataset(["/root/training_data/vision-dataset/image_upscaler.parquet", "/root/training_data/vision-dataset/image_vec_upscaler.parquet"], transform=transform)
 data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
 
 # Define a loss function and optimizer