develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed file with 60 additions and 22 deletions
Showing only changes of commit bac4a9010a


@@ -1,58 +1,96 @@
 import torch
 import pandas as pd
-import numpy as np
-import cv2
-import os
 from albumentations import (
     Compose, Resize, Normalize, RandomBrightnessContrast,
     HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
 )
 from albumentations.pytorch import ToTensorV2
+from PIL import Image, ImageFile
+import io
+import base64
 from torch import nn
 # Import the model and config from your existing code
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
 
 class aiuNNDataset(torch.utils.data.Dataset):
     def __init__(self, parquet_path, config=None):
         # Read the Parquet file
         self.df = pd.read_parquet(parquet_path)
         # Data augmentation pipeline
         self.augmentation = Compose([
-            Resize(height=512, width=512),
+            Resize((512, 512)),
             RandomBrightnessContrast(),
             HorizontalFlip(p=0.5),
             VerticalFlip(p=0.5),
-            Rotate(limit=45),
-            GaussianBlur(p=0.3),
+            Rotate(degrees=45),
+            GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
             Normalize(mean=[0.5], std=[0.5]),
             ToTensorV2()
         ])
 
     def __len__(self):
         return len(self.df)
 
+    def load_image(self, image_data):
+        try:
+            # Handle both bytes and base64 encoded strings
+            if isinstance(image_data, str):
+                # Decode base64 string to bytes
+                image_data = base64.b64decode(image_data)
+            # Verify data is valid before creating BytesIO
+            if not isinstance(image_data, bytes):
+                raise ValueError("Invalid image data format")
+            # Create image stream
+            image_stream = io.BytesIO(image_data)
+            # Enable loading of truncated images
+            ImageFile.LOAD_TRUNCATED_IMAGES = True
+            # Load and convert image to RGB
+            image = Image.open(image_stream).convert('RGB')
+            # Create fresh copy for verify() since it modifies the image object
+            image_verify = image.copy()
+            # Verify image is valid
+            try:
+                image_verify.verify()
+            except Exception as e:
+                raise ValueError(f"Image verification failed: {str(e)}")
+            finally:
+                image_verify.close()
+            return image
+        except Exception as e:
+            raise RuntimeError(f"Error loading image: {str(e)}")
+        finally:
+            # Ensure stream is closed
+            if 'image_stream' in locals():
+                image_stream.close()
+
     def __getitem__(self, idx):
-        # Get the byte strings
-        low_res_bytes = self.df.iloc[idx]['image_512']
-        high_res_bytes = self.df.iloc[idx]['image_1024']
-        # Convert bytes to numpy arrays
-        low_res = cv2.imdecode(np.frombuffer(low_res_bytes, np.uint8), -1)
-        high_res = cv2.imdecode(np.frombuffer(high_res_bytes, np.uint8), -1)
+        row = self.df.iloc[idx]
+        # Load images using the new method
+        low_res_image = self.load_image(row['image_512'])
+        high_res_image = self.load_image(row['image_1024'])
         # Apply augmentation and normalization
-        augmented = self.augmentation(image=low_res, mask=high_res)
+        augmented = self.augmentation(image=low_res_image, mask=high_res_image)
         low_res = augmented['image']
         high_res = augmented['mask']
         return {
             'low_res': low_res,
             'high_res': high_res
         }
 
-def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=2, epochs = 10):
+def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=8, epochs = 10):
     # Initialize dataset and dataloader
     train_dataset = aiuNNDataset(train_parquet_path)
     val_dataset = aiuNNDataset(val_parquet_path)
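
Note on the augmentation changes above: Resize((512, 512)), Rotate(degrees=45), and GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)) are torchvision-style signatures, while the transforms imported here come from albumentations, which takes Resize(height=..., width=...), Rotate(limit=...), and GaussianBlur(blur_limit=..., sigma_limit=..., p=...) and operates on numpy arrays rather than PIL images. Below is a minimal sketch of an albumentations-compatible pipeline plus the PIL-to-numpy conversion, keeping the same 512/1024 column layout as the diff; the apply_augmentation helper is illustrative only and not part of this PR.

import numpy as np
from PIL import Image
from albumentations import (
    Compose, Resize, Normalize, RandomBrightnessContrast,
    HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
)
from albumentations.pytorch import ToTensorV2

# Albumentations-style argument names: height/width, limit, blur_limit/sigma_limit
augmentation = Compose([
    Resize(height=512, width=512),
    RandomBrightnessContrast(),
    HorizontalFlip(p=0.5),
    VerticalFlip(p=0.5),
    Rotate(limit=45),
    GaussianBlur(blur_limit=(3, 3), sigma_limit=(0.1, 2.0), p=0.3),
    Normalize(mean=[0.5], std=[0.5]),
    ToTensorV2()
])

def apply_augmentation(low_res_image: Image.Image, high_res_image: Image.Image):
    # Albumentations expects numpy arrays, so the PIL images returned by
    # load_image() are converted before the transform is applied.
    augmented = augmentation(
        image=np.array(low_res_image),
        mask=np.array(high_res_image)
    )
    return augmented['image'], augmented['mask']

Depending on the albumentations version, Compose may also validate that image and mask share the same height and width, so routing the 512 px input and the 1024 px target through one pipeline via the mask slot may need separate transforms; worth confirming against the version pinned for this repo.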