develop #4

Merged
Fabel merged 103 commits from develop into main on 2025-03-01 21:47:17 +00:00
1 changed file with 60 additions and 22 deletions
Showing only changes of commit bac4a9010a

@@ -1,13 +1,13 @@
 import torch
 import pandas as pd
 import numpy as np
 import cv2
 import os
 from albumentations import (
     Compose, Resize, Normalize, RandomBrightnessContrast,
     HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
 )
 from albumentations.pytorch import ToTensorV2
+from PIL import Image, ImageFile
+import io
+import base64
 from torch import nn
 # Import the model and config from your existing code
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
@@ -19,12 +19,12 @@ class aiuNNDataset(torch.utils.data.Dataset):
         # Data augmentation pipeline
         self.augmentation = Compose([
-            Resize(height=512, width=512),
+            Resize((512, 512)),
             RandomBrightnessContrast(),
             HorizontalFlip(p=0.5),
             VerticalFlip(p=0.5),
-            Rotate(limit=45),
-            GaussianBlur(p=0.3),
+            Rotate(degrees=45),
+            GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
             Normalize(mean=[0.5], std=[0.5]),
             ToTensorV2()
         ])
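
Note on the new transform arguments: Compose, Resize, Rotate, and GaussianBlur are imported from albumentations above, but Resize((512, 512)), Rotate(degrees=45), and GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)) follow torchvision.transforms signatures and would raise a TypeError inside an albumentations Compose. A minimal sketch of the albumentations-native equivalents (the blur_limit range is an assumed value, not taken from this commit):

    # albumentations-style signatures for the same pipeline
    Resize(height=512, width=512),
    Rotate(limit=45),                        # albumentations uses limit=, not degrees=
    GaussianBlur(blur_limit=(3, 7), p=0.3),  # odd kernel-size range plus apply probability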
@@ -32,17 +32,56 @@ class aiuNNDataset(torch.utils.data.Dataset):
     def __len__(self):
         return len(self.df)
-    def __getitem__(self, idx):
-        # Get the byte strings
-        low_res_bytes = self.df.iloc[idx]['image_512']
-        high_res_bytes = self.df.iloc[idx]['image_1024']
+    def load_image(self, image_data):
+        try:
+            # Handle both bytes and base64 encoded strings
+            if isinstance(image_data, str):
+                # Decode base64 string to bytes
+                image_data = base64.b64decode(image_data)
-        # Convert bytes to numpy arrays
-        low_res = cv2.imdecode(np.frombuffer(low_res_bytes, np.uint8), -1)
-        high_res = cv2.imdecode(np.frombuffer(high_res_bytes, np.uint8), -1)
+            # Verify data is valid before creating BytesIO
+            if not isinstance(image_data, bytes):
+                raise ValueError("Invalid image data format")
+            # Create image stream
+            image_stream = io.BytesIO(image_data)
+            # Enable loading of truncated images
+            ImageFile.LOAD_TRUNCATED_IMAGES = True
+            # Load and convert image to RGB
+            image = Image.open(image_stream).convert('RGB')
+            # Create fresh copy for verify() since it modifies the image object
+            image_verify = image.copy()
+            # Verify image is valid
+            try:
+                image_verify.verify()
+            except Exception as e:
+                raise ValueError(f"Image verification failed: {str(e)}")
+            finally:
+                image_verify.close()
+            return image
+        except Exception as e:
+            raise RuntimeError(f"Error loading image: {str(e)}")
+        finally:
+            # Ensure stream is closed
+            if 'image_stream' in locals():
+                image_stream.close()
+    def __getitem__(self, idx):
+        row = self.df.iloc[idx]
+        # Load images using the new method
+        low_res_image = self.load_image(row['image_512'])
+        high_res_image = self.load_image(row['image_1024'])
         # Apply augmentation and normalization
-        augmented = self.augmentation(image=low_res, mask=high_res)
+        augmented = self.augmentation(image=low_res_image, mask=high_res_image)
         low_res = augmented['image']
         high_res = augmented['mask']
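
A second caveat: load_image returns PIL.Image objects, while albumentations transforms operate on numpy arrays. A hedged sketch of the conversion that would sit between loading and augmentation (variable names mirror the diff; the conversion itself is an assumption about the intended fix):

    # albumentations expects HWC numpy arrays, not PIL Images
    low_res = np.array(low_res_image)
    high_res = np.array(high_res_image)
    augmented = self.augmentation(image=low_res, mask=high_res)

Passing the high-res target as mask= also means photometric transforms (RandomBrightnessContrast, GaussianBlur, Normalize) apply only to image and not to the target, which may or may not be intended for super-resolution pairs.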
@@ -51,8 +90,7 @@ class aiuNNDataset(torch.utils.data.Dataset):
             'high_res': high_res
         }
-def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=8, epochs=10):
+def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=2, epochs=10):
     # Initialize dataset and dataloader
     train_dataset = aiuNNDataset(train_parquet_path)
     val_dataset = aiuNNDataset(val_parquet_path)
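
The rest of finetune_model is cut off in this view. For context, a minimal sketch of how these datasets are typically wrapped into loaders with the new batch_size=2 default (shuffle and num_workers are assumptions, not shown in this commit):

    # hypothetical loader setup; only the dataset construction is visible above
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=4
    )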