import base64
import io

import numpy as np
import pandas as pd
import torch
from albumentations import (
    Compose, Normalize, RandomBrightnessContrast,
    HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
)
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageFile
from torch import nn
from torch.utils.data import random_split
from tqdm import tqdm

from aiia import AIIA, AIIABase

# Allow Pillow to decode images whose byte streams are truncated.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class aiuNNDataset(torch.utils.data.Dataset):
    def __init__(self, parquet_path):
        self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2500)
        # One Compose with an additional target, so every random spatial
        # transform (flip, rotation, ...) is applied with identical parameters
        # to both resolutions; augmenting the two images independently would
        # destroy the low-res/high-res pairing.
        self.augmentation = Compose(
            [
                RandomBrightnessContrast(p=0.5),
                HorizontalFlip(p=0.5),
                VerticalFlip(p=0.5),
                Rotate(limit=45, p=0.5),
                GaussianBlur(blur_limit=(3, 7), p=0.5),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                ToTensorV2(),
            ],
            additional_targets={'high_res': 'image'},
        )

    def __len__(self):
        return len(self.df)

    def load_image(self, image_data):
        """Decode a base64 string or raw bytes into an RGB numpy array."""
        image_stream = None
        try:
            if isinstance(image_data, str):
                image_data = base64.b64decode(image_data)
            if not isinstance(image_data, bytes):
                raise ValueError("Invalid image data format")
            image_stream = io.BytesIO(image_data)
            image = Image.open(image_stream).convert('RGB')
            return np.array(image)
        except Exception as e:
            raise RuntimeError(f"Error loading image: {e}")
        finally:
            if image_stream is not None:
                image_stream.close()

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        low_res_image = self.load_image(row['image_512'])
        high_res_image = self.load_image(row['image_1024'])
        # A single call augments both images with the same random parameters.
        augmented = self.augmentation(image=low_res_image, high_res=high_res_image)
        return {
            'low_res': augmented['image'],
            'high_res': augmented['high_res'],
        }


def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
    loaded_datasets = [aiuNNDataset(d) for d in datasets]
    combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)

    # 80/20 train/validation split.
    train_size = int(0.8 * len(combined_dataset))
    val_size = len(combined_dataset) - train_size
    train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
    best_val_loss = float('inf')

    for epoch in range(epochs):
        # --- Training ---
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} - Training"):
            low_res = batch['low_res'].to(device)
            high_res = batch['high_res'].to(device)

            optimizer.zero_grad()
            outputs = model(low_res)
            loss = criterion(outputs, high_res)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch + 1}, Training Loss: {avg_train_loss:.4f}")

        # --- Validation ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                low_res = batch['low_res'].to(device)
                high_res = batch['high_res'].to(device)
                outputs = model(low_res)
                loss = criterion(outputs, high_res)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch + 1}, Validation Loss: {avg_val_loss:.4f}")

        # Keep the checkpoint with the lowest validation loss.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "best_model.pth")

    return model


def main():
    BATCH_SIZE = 2
    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")

    # Note: add_module() only registers the layer as a submodule; it is not
    # executed unless the model's forward() actually calls it.
    if hasattr(model, 'chunked_'):
        model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))

    finetune_model(
        model=model,
        datasets=[
            "/root/training_data/vision-dataset/image_upscaler.parquet",
            "/root/training_data/vision-dataset/image_vec_upscaler.parquet",
        ],
        batch_size=BATCH_SIZE,
    )


if __name__ == '__main__':
    main()
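
# --- Inference sketch (illustrative only, not part of the training script) ---
# A minimal way to use the checkpoint that finetune_model() saves. It assumes
# AIIABase.load() returns an nn.Module whose fine-tuned forward pass maps a
# (N, 3, 512, 512) tensor to a (N, 3, 1024, 1024) tensor; the model path and
# the upscale() helper are hypothetical and should be adapted to your setup.
# The preprocessing mirrors the training-time Normalize(mean=0.5, std=0.5).
#
# import numpy as np
# import torch
# from PIL import Image
# from aiia import AIIABase
#
# def upscale(image_path: str, checkpoint: str = "best_model.pth") -> Image.Image:
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
#     model.load_state_dict(torch.load(checkpoint, map_location=device))
#     model = model.to(device).eval()
#
#     # Scale to [0, 1], then apply the same (x - 0.5) / 0.5 normalization
#     # used during training before arranging the array as NCHW.
#     img = np.asarray(Image.open(image_path).convert('RGB'), dtype=np.float32) / 255.0
#     tensor = torch.from_numpy((img - 0.5) / 0.5).permute(2, 0, 1).unsqueeze(0).to(device)
#
#     with torch.no_grad():
#         out = model(tensor)
#
#     # Denormalize back to [0, 1] and convert to a PIL image.
#     out = (out.squeeze(0).permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1)
#     return Image.fromarray((out * 255).astype(np.uint8))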