finetune_class #1

@@ -1,13 +1,13 @@
 import torch
 import pandas as pd
-import numpy as np
-import cv2
-import os
 from albumentations import (
     Compose, Resize, Normalize, RandomBrightnessContrast,
     HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
 )
 from albumentations.pytorch import ToTensorV2
+from PIL import Image, ImageFile
+import io
+import base64
 from torch import nn
 # Import the model and config from your existing code
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
@@ -19,12 +19,12 @@ class aiuNNDataset(torch.utils.data.Dataset):
 
         # Data augmentation pipeline
         self.augmentation = Compose([
-            Resize(height=512, width=512),
+            Resize((512, 512)),
             RandomBrightnessContrast(),
             HorizontalFlip(p=0.5),
             VerticalFlip(p=0.5),
-            Rotate(limit=45),
-            GaussianBlur(p=0.3),
+            Rotate(degrees=45),
+            GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
             Normalize(mean=[0.5], std=[0.5]),
             ToTensorV2()
         ])
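A side note on the renamed keyword arguments in this hunk: `degrees`, `kernel_size`, and `sigma` are torchvision-style parameter names, while this pipeline is built from albumentations, whose transforms expect `Resize(height=..., width=...)`, `Rotate(limit=...)`, and `GaussianBlur(blur_limit=..., sigma_limit=...)`. Below is a hedged sketch of the same pipeline written against the albumentations signatures; the `blur_limit=(3, 3)` / `sigma_limit=(0.1, 2.0)` / `p=0.3` mapping is an assumption based on the old and new values, not something stated in the diff, and tuple-valued `sigma_limit` assumes a recent albumentations release.

    # Sketch: same augmentation pipeline, using albumentations' own argument names.
    # Values are carried over from the diff; the GaussianBlur mapping is an assumption.
    from albumentations import (
        Compose, Resize, Normalize, RandomBrightnessContrast,
        HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
    )
    from albumentations.pytorch import ToTensorV2

    augmentation = Compose([
        Resize(height=512, width=512),          # Resize takes height/width, not a size tuple
        RandomBrightnessContrast(),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        Rotate(limit=45),                       # the rotation range in degrees is called 'limit'
        GaussianBlur(blur_limit=(3, 3), sigma_limit=(0.1, 2.0), p=0.3),  # assumed equivalent
        Normalize(mean=[0.5], std=[0.5]),
        ToTensorV2(),
    ])

Also worth noting: albumentations applies pixel-level transforms such as RandomBrightnessContrast, GaussianBlur, and Normalize only to the `image` target, so the high-resolution image passed in as `mask` later in `__getitem__` is only resized, flipped, and rotated.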
@@ -32,17 +32,56 @@ class aiuNNDataset(torch.utils.data.Dataset):
     def __len__(self):
         return len(self.df)
 
-    def __getitem__(self, idx):
-        # Get the byte strings
-        low_res_bytes = self.df.iloc[idx]['image_512']
-        high_res_bytes = self.df.iloc[idx]['image_1024']
-
-        # Convert bytes to numpy arrays
-        low_res = cv2.imdecode(np.frombuffer(low_res_bytes, np.uint8), -1)
-        high_res = cv2.imdecode(np.frombuffer(high_res_bytes, np.uint8), -1)
+    def load_image(self, image_data):
+        try:
+            # Handle both bytes and base64 encoded strings
+            if isinstance(image_data, str):
+                # Decode base64 string to bytes
+                image_data = base64.b64decode(image_data)
+
+            # Verify data is valid before creating BytesIO
+            if not isinstance(image_data, bytes):
+                raise ValueError("Invalid image data format")
+
+            # Create image stream
+            image_stream = io.BytesIO(image_data)
+
+            # Enable loading of truncated images
+            ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+            # Load and convert image to RGB
+            image = Image.open(image_stream).convert('RGB')
+
+            # Create fresh copy for verify() since it modifies the image object
+            image_verify = image.copy()
+
+            # Verify image is valid
+            try:
+                image_verify.verify()
+            except Exception as e:
+                raise ValueError(f"Image verification failed: {str(e)}")
+            finally:
+                image_verify.close()
+
+            return image
+
+        except Exception as e:
+            raise RuntimeError(f"Error loading image: {str(e)}")
+
+        finally:
+            # Ensure stream is closed
+            if 'image_stream' in locals():
+                image_stream.close()
+
+    def __getitem__(self, idx):
+        row = self.df.iloc[idx]
+
+        # Load images using the new method
+        low_res_image = self.load_image(row['image_512'])
+        high_res_image = self.load_image(row['image_1024'])
 
         # Apply augmentation and normalization
-        augmented = self.augmentation(image=low_res, mask=high_res)
+        augmented = self.augmentation(image=low_res_image, mask=high_res_image)
         low_res = augmented['image']
         high_res = augmented['mask']
 
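One practical caveat on this hunk: `load_image` returns a PIL `Image`, but albumentations' `Compose` operates on numpy arrays, so calling `self.augmentation(image=low_res_image, mask=high_res_image)` with PIL objects will typically fail. A minimal sketch of `__getitem__` with the extra `np.array(...)` casts; everything else is taken from the diff, and `numpy` would need to stay imported at module level even though this change removes it from the imports.

    # Sketch of __getitem__ with PIL -> numpy casts added; the rest matches the diff.
    import numpy as np  # keep this at module level; it is removed in the import hunk above

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Convert PIL -> numpy (H, W, C uint8), the layout albumentations' Compose expects
        low_res_image = np.array(self.load_image(row['image_512']))
        high_res_image = np.array(self.load_image(row['image_1024']))

        # Apply augmentation and normalization
        augmented = self.augmentation(image=low_res_image, mask=high_res_image)
        low_res = augmented['image']
        high_res = augmented['mask']
        # ... return the tensors as in the following hunk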
@@ -51,8 +90,7 @@ class aiuNNDataset(torch.utils.data.Dataset):
             'high_res': high_res
         }
 
-def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=2, epochs = 10):
+def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=8, epochs = 10):
     # Initialize dataset and dataloader
     train_dataset = aiuNNDataset(train_parquet_path)
     val_dataset = aiuNNDataset(val_parquet_path)
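The view ends right after the datasets are created, so the DataLoader construction that the "Initialize dataset and dataloader" comment refers to is not visible here. Purely as an illustrative sketch, the `train_loader`/`val_loader` names and the `shuffle`/`num_workers`/`pin_memory` settings below are assumptions and not part of this change; only `batch_size` comes from the function signature above.

    from torch.utils.data import DataLoader

    # Hypothetical loader setup; names and options are assumptions, not taken from the diff.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=4, pin_memory=True)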