develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed file with 60 additions and 22 deletions
Showing only changes of commit bac4a9010a

@@ -1,13 +1,13 @@
 import torch
 import pandas as pd
-import numpy as np
-import cv2
-import os
 from albumentations import (
     Compose, Resize, Normalize, RandomBrightnessContrast,
     HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
 )
 from albumentations.pytorch import ToTensorV2
+from PIL import Image, ImageFile
+import io
+import base64
 from torch import nn
 # Import the model and config from your existing code
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
@@ -19,12 +19,12 @@ class aiuNNDataset(torch.utils.data.Dataset):
         # Data augmentation pipeline
         self.augmentation = Compose([
-            Resize(height=512, width=512),
+            Resize((512, 512)),
             RandomBrightnessContrast(),
             HorizontalFlip(p=0.5),
             VerticalFlip(p=0.5),
-            Rotate(limit=45),
-            GaussianBlur(p=0.3),
+            Rotate(degrees=45),
+            GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
             Normalize(mean=[0.5], std=[0.5]),
             ToTensorV2()
         ])
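Review note: the replacement keyword arguments on the right-hand side follow torchvision's transform API, but these classes are imported from albumentations, where `Resize` takes `height=`/`width=`, `Rotate` takes `limit=`, and `GaussianBlur` takes `blur_limit=`/`sigma_limit=`. As written, the new pipeline would be rejected when `Compose` is constructed in `__init__`. A minimal sketch with albumentations-native arguments, assuming a recent albumentations release (the probabilities are carried over unchanged):

```python
from albumentations import (
    Compose, Resize, Normalize, RandomBrightnessContrast,
    HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
)
from albumentations.pytorch import ToTensorV2

# Same pipeline, spelled with albumentations' own keyword arguments.
augmentation = Compose([
    Resize(height=512, width=512),    # height=/width=, not a size tuple
    RandomBrightnessContrast(),
    HorizontalFlip(p=0.5),
    VerticalFlip(p=0.5),
    Rotate(limit=45),                 # limit=, not degrees=
    GaussianBlur(blur_limit=(3, 3), sigma_limit=(0.1, 2.0), p=0.3),
    Normalize(mean=[0.5], std=[0.5]),
    ToTensorV2(),
])
```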
@@ -32,17 +32,56 @@ class aiuNNDataset(torch.utils.data.Dataset):
     def __len__(self):
         return len(self.df)
 
-    def __getitem__(self, idx):
-        # Get the byte strings
-        low_res_bytes = self.df.iloc[idx]['image_512']
-        high_res_bytes = self.df.iloc[idx]['image_1024']
-        # Convert bytes to numpy arrays
-        low_res = cv2.imdecode(np.frombuffer(low_res_bytes, np.uint8), -1)
-        high_res = cv2.imdecode(np.frombuffer(high_res_bytes, np.uint8), -1)
+    def load_image(self, image_data):
+        try:
+            # Handle both bytes and base64 encoded strings
+            if isinstance(image_data, str):
+                # Decode base64 string to bytes
+                image_data = base64.b64decode(image_data)
+            # Verify data is valid before creating BytesIO
+            if not isinstance(image_data, bytes):
+                raise ValueError("Invalid image data format")
+            # Create image stream
+            image_stream = io.BytesIO(image_data)
+            # Enable loading of truncated images
+            ImageFile.LOAD_TRUNCATED_IMAGES = True
+            # Load and convert image to RGB
+            image = Image.open(image_stream).convert('RGB')
+            # Create fresh copy for verify() since it modifies the image object
+            image_verify = image.copy()
+            # Verify image is valid
+            try:
+                image_verify.verify()
+            except Exception as e:
+                raise ValueError(f"Image verification failed: {str(e)}")
+            finally:
+                image_verify.close()
+            return image
+        except Exception as e:
+            raise RuntimeError(f"Error loading image: {str(e)}")
+        finally:
+            # Ensure stream is closed
+            if 'image_stream' in locals():
+                image_stream.close()
+
+    def __getitem__(self, idx):
+        row = self.df.iloc[idx]
+        # Load images using the new method
+        low_res_image = self.load_image(row['image_512'])
+        high_res_image = self.load_image(row['image_1024'])
         # Apply augmentation and normalization
-        augmented = self.augmentation(image=low_res, mask=high_res)
+        augmented = self.augmentation(image=low_res_image, mask=high_res_image)
         low_res = augmented['image']
         high_res = augmented['mask']
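Two review notes on the new loading path. First, `ImageFile.LOAD_TRUNCATED_IMAGES` is a process-global PIL flag, so setting it inside `load_image` flips it for every later decode in the process; the `.convert('RGB')` call does force PIL's lazy decode, which is why closing the stream in the `finally` block is safe. Second, `load_image` returns PIL images, but albumentations transforms operate on numpy arrays, so the `self.augmentation(image=low_res_image, ...)` call would fail as written, and the commit also drops the `import numpy as np` it would need. It is also worth noting that `mask` targets in albumentations receive only the spatial transforms (skipping Normalize and the pixel-level augmentations), and `Resize` will scale the 1024 px target down to 512 px. A minimal sketch of the missing conversion, assuming the pipeline stays in albumentations:

```python
import numpy as np
from PIL import Image

def to_array(image: Image.Image) -> np.ndarray:
    """Convert a PIL image to the HWC uint8 array albumentations expects."""
    return np.array(image)

# Inside __getitem__ the call would then become:
#   augmented = self.augmentation(image=to_array(low_res_image),
#                                 mask=to_array(high_res_image))
```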
@@ -51,8 +90,7 @@ class aiuNNDataset(torch.utils.data.Dataset):
         'high_res': high_res
     }
 
-def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=2, epochs = 10):
+def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=8, epochs = 10):
     # Initialize dataset and dataloader
     train_dataset = aiuNNDataset(train_parquet_path)
     val_dataset = aiuNNDataset(val_parquet_path)
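The hunk ends before the rest of `finetune_model` is shown. For readers skimming the PR, the usual next step with these datasets is plain PyTorch `DataLoader` construction; a hypothetical sketch (nothing beyond `train_dataset`, `val_dataset`, and `batch_size` comes from the diff):

```python
from torch.utils.data import DataLoader

def make_loaders(train_dataset, val_dataset, batch_size=8):
    # Hypothetical continuation: standard DataLoader setup matching the
    # batch_size parameter introduced above.
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, val_loader
```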