diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 34a08c3..23434cd 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -1,58 +1,88 @@
 import torch
 import pandas as pd
 import numpy as np
-import cv2
-import os
 from albumentations import (
     Compose, Resize, Normalize, RandomBrightnessContrast,
     HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
 )
 from albumentations.pytorch import ToTensorV2
+from PIL import Image, ImageFile
+import io
+import base64
 from torch import nn
 
 # Import the model and config from your existing code
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
 
 class aiuNNDataset(torch.utils.data.Dataset):
     def __init__(self, parquet_path, config=None):
         # Read the Parquet file
         self.df = pd.read_parquet(parquet_path)
 
         # Data augmentation pipeline
         self.augmentation = Compose([
             Resize(height=512, width=512),
             RandomBrightnessContrast(),
             HorizontalFlip(p=0.5),
             VerticalFlip(p=0.5),
             Rotate(limit=45),
-            GaussianBlur(p=0.3),
+            GaussianBlur(blur_limit=(3, 3), sigma_limit=(0.1, 2.0), p=0.3),
             Normalize(mean=[0.5], std=[0.5]),
             ToTensorV2()
         ])
 
     def __len__(self):
         return len(self.df)
 
+    def load_image(self, image_data):
+        try:
+            # Handle both raw bytes and base64 encoded strings
+            if isinstance(image_data, str):
+                # Decode base64 string to bytes
+                image_data = base64.b64decode(image_data)
+
+            # Verify data is valid before decoding
+            if not isinstance(image_data, bytes):
+                raise ValueError("Invalid image data format")
+
+            # Enable loading of truncated images
+            ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+            # verify() must be called on a freshly opened image and leaves
+            # it unusable afterwards, so verify on a separate stream first
+            with Image.open(io.BytesIO(image_data)) as image_verify:
+                try:
+                    image_verify.verify()
+                except Exception as e:
+                    raise ValueError(f"Image verification failed: {str(e)}")
+
+            # Decode again from a fresh stream and convert to RGB;
+            # convert() fully loads the data, so the stream can be closed
+            with io.BytesIO(image_data) as image_stream:
+                return Image.open(image_stream).convert('RGB')
+
+        except Exception as e:
+            raise RuntimeError(f"Error loading image: {str(e)}") from e
+
     def __getitem__(self, idx):
-        # Get the byte strings
-        low_res_bytes = self.df.iloc[idx]['image_512']
-        high_res_bytes = self.df.iloc[idx]['image_1024']
-
-        # Convert bytes to numpy arrays
-        low_res = cv2.imdecode(np.frombuffer(low_res_bytes, np.uint8), -1)
-        high_res = cv2.imdecode(np.frombuffer(high_res_bytes, np.uint8), -1)
-
-        # Apply augmentation and normalization
-        augmented = self.augmentation(image=low_res, mask=high_res)
+        row = self.df.iloc[idx]
+
+        # Load images and convert to numpy arrays for albumentations
+        low_res_image = np.array(self.load_image(row['image_512']))
+        high_res_image = np.array(self.load_image(row['image_1024']))
+
+        # Apply augmentation; spatial transforms (resize, flips, rotation)
+        # affect both image and mask, pixel-level ones only the image
+        augmented = self.augmentation(image=low_res_image, mask=high_res_image)
         low_res = augmented['image']
         high_res = augmented['mask']
 
         return {
             'low_res': low_res,
             'high_res': high_res
         }
 
-def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=8, epochs = 10):
+def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=2, epochs=10):
     # Initialize dataset and dataloader
     train_dataset = aiuNNDataset(train_parquet_path)
     val_dataset = aiuNNDataset(val_parquet_path)
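
For orientation, here is a minimal usage sketch of the reworked dataset (not part of the diff); the parquet path is a placeholder. It also makes one consequence of routing the high-res target through the mask slot visible: the mask is resized to 512x512 alongside the input and is skipped by Normalize and by ToTensorV2's channel transpose.

# Minimal usage sketch; "train.parquet" is a hypothetical path
from torch.utils.data import DataLoader

dataset = aiuNNDataset("train.parquet")
loader = DataLoader(dataset, batch_size=2, shuffle=True)

batch = next(iter(loader))
print(batch['low_res'].shape)   # float tensor, CHW, normalized
print(batch['high_res'].shape)  # uint8 tensor, HWC, untouched by Normalize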
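
If the 1024px target is meant to stay at full resolution, the single Compose pipeline above cannot express that, because its Resize also hits the mask. One alternative, sketched below under the assumption of the same parquet columns, is to share the random flip decisions manually and give each resolution its own resize/normalize pipeline; the names low_res_tf, high_res_tf, and paired_augment are hypothetical.

import random
import numpy as np
from albumentations import Compose, Resize, Normalize
from albumentations.pytorch import ToTensorV2

# Per-resolution pipelines so the target keeps its 1024px size (sketch only)
low_res_tf = Compose([Resize(height=512, width=512),
                      Normalize(mean=[0.5], std=[0.5]), ToTensorV2()])
high_res_tf = Compose([Resize(height=1024, width=1024),
                       Normalize(mean=[0.5], std=[0.5]), ToTensorV2()])

def paired_augment(low: np.ndarray, high: np.ndarray):
    # Apply the same random flips to both images with shared coin tosses;
    # .copy() removes the negative strides torch.from_numpy cannot handle
    if random.random() < 0.5:
        low, high = np.fliplr(low).copy(), np.fliplr(high).copy()
    if random.random() < 0.5:
        low, high = np.flipud(low).copy(), np.flipud(high).copy()
    return low_res_tf(image=low)['image'], high_res_tf(image=high)['image']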