updated dataset loading

Falko Victor Habel 2025-02-06 18:33:43 +01:00
parent 72a5959bc1
commit bac4a9010a
1 changed file with 60 additions and 22 deletions


@@ -1,58 +1,96 @@
import torch
import pandas as pd
import numpy as np
import cv2
import os
from albumentations import (
    Compose, Resize, Normalize, RandomBrightnessContrast,
    HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
)
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageFile
import io
import base64
from torch import nn
# Import the model and config from your existing code
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
class aiuNNDataset(torch.utils.data.Dataset):
    def __init__(self, parquet_path, config=None):
        # Read the Parquet file
        self.df = pd.read_parquet(parquet_path)
        # Data augmentation pipeline (albumentations API)
        self.augmentation = Compose([
            Resize(height=512, width=512),
            RandomBrightnessContrast(),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            Rotate(limit=45),
            GaussianBlur(p=0.3),
            Normalize(mean=[0.5], std=[0.5]),
            ToTensorV2()
        ])
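        # Note: albumentations applies the spatial transforms (Resize, flips, Rotate)
        # to both the image and the mask target, while pixel-level transforms
        # (RandomBrightnessContrast, GaussianBlur, Normalize) only affect the image.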

    def __len__(self):
        return len(self.df)

    def load_image(self, image_data):
        try:
            # Handle both bytes and base64 encoded strings
            if isinstance(image_data, str):
                # Decode base64 string to bytes
                image_data = base64.b64decode(image_data)

            # Verify data is valid before creating BytesIO
            if not isinstance(image_data, bytes):
                raise ValueError("Invalid image data format")

            # Create image stream
            image_stream = io.BytesIO(image_data)

            # Enable loading of truncated images
            ImageFile.LOAD_TRUNCATED_IMAGES = True

            # Load and convert image to RGB
            image = Image.open(image_stream).convert('RGB')

            # Create fresh copy for verify() since it modifies the image object
            image_verify = image.copy()

            # Verify image is valid
            try:
                image_verify.verify()
            except Exception as e:
                raise ValueError(f"Image verification failed: {str(e)}")
            finally:
                image_verify.close()

            return image
        except Exception as e:
            raise RuntimeError(f"Error loading image: {str(e)}")
        finally:
            # Ensure stream is closed
            if 'image_stream' in locals():
                image_stream.close()

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Load images using the new method
        low_res_image = self.load_image(row['image_512'])
        high_res_image = self.load_image(row['image_1024'])

        # Albumentations operates on numpy arrays, so convert the PIL images first
        low_res_np = np.array(low_res_image)
        high_res_np = np.array(high_res_image)

        # Apply augmentation and normalization (high_res is passed as the mask so
        # both images receive the same spatial transforms)
        augmented = self.augmentation(image=low_res_np, mask=high_res_np)
        low_res = augmented['image']
        high_res = augmented['mask']

        return {
            'low_res': low_res,
            'high_res': high_res
        }
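
# Hypothetical usage check (not part of this commit): pull one sample to confirm
# the pipeline returns an augmented low/high resolution pair, e.g.:
#
#     dataset = aiuNNDataset("train.parquet")  # path is illustrative
#     sample = dataset[0]
#     print(sample['low_res'].shape, sample['high_res'].shape)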

def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=2, epochs=10):
    # Initialize dataset and dataloader
    train_dataset = aiuNNDataset(train_parquet_path)
    val_dataset = aiuNNDataset(val_parquet_path)
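
    # Minimal sketch (assumption, the rest of this hunk is not shown): wrap the
    # datasets in PyTorch DataLoaders before the training loop.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)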