import base64
import io

import numpy as np
import pandas as pd
import torch
from albumentations import (
    Compose, Normalize, RandomBrightnessContrast,
    HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
)
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageFile
from torch import nn
from torch.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from aiia import AIIABase


class aiuNNDataset(torch.utils.data.Dataset):
    def __init__(self, parquet_path):
        # Only the two image columns are needed; cap the row count to keep
        # memory usage bounded.
        self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(10000)
        # A single pipeline with an additional target so the low-res input and
        # the high-res target of a pair receive identical random transforms;
        # augmenting the two independently would destroy their pixel
        # correspondence. The shape check must be disabled because the paired
        # images intentionally differ in size (512 vs. 1024).
        self.augmentation = Compose(
            [
                RandomBrightnessContrast(p=0.5),
                HorizontalFlip(p=0.5),
                VerticalFlip(p=0.5),
                Rotate(limit=45, p=0.5),
                GaussianBlur(blur_limit=(3, 7), p=0.5),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                ToTensorV2()
            ],
            additional_targets={'high_res': 'image'},
            is_check_shapes=False
        )

    def __len__(self):
        return len(self.df)

    def load_image(self, image_data):
        image_stream = None
        try:
            # Stored images may be base64-encoded strings or raw bytes.
            if isinstance(image_data, str):
                image_data = base64.b64decode(image_data)
            if not isinstance(image_data, bytes):
                raise ValueError("Invalid image data format")
            image_stream = io.BytesIO(image_data)
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            image = Image.open(image_stream).convert('RGB')
            return np.array(image)
        except Exception as e:
            raise RuntimeError(f"Error loading image: {e}")
        finally:
            if image_stream is not None:
                image_stream.close()

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        low_res_image = self.load_image(row['image_512'])
        high_res_image = self.load_image(row['image_1024'])
        # One call, so both images share the same random transform parameters.
        augmented = self.augmentation(image=low_res_image, high_res=high_res_image)
        return {
            'low_res': augmented['image'],
            'high_res': augmented['high_res']
        }


class Upscaler(nn.Module):
    """
    Transforms the base model's final feature map using a transposed
    convolution. The base model produces a feature map of size 512x512;
    this layer upsamples it by a factor of 2 (yielding 1024x1024) and maps
    the hidden features to the output channels with a single ConvTranspose2d.
    """
    def __init__(self, base_model: AIIABase):
        super().__init__()
        self.base_model = base_model
        # Instead of separate upsampling and convolutional layers, a single
        # ConvTranspose2d does both. With stride=2, padding=(k - 1) // 2 and
        # output_padding=1, the spatial size doubles exactly for any odd
        # kernel size k.
        kernel_size = base_model.config.kernel_size
        self.last_transform = nn.ConvTranspose2d(
            in_channels=base_model.config.hidden_size,
            out_channels=base_model.config.num_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=(kernel_size - 1) // 2,
            output_padding=1
        )

    def forward(self, x):
        features = self.base_model(x)
        return self.last_transform(features)
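
# --- Shape sanity check (illustrative) ---------------------------------------
# A minimal sketch verifying the doubling formula used by Upscaler above:
#   out = (in - 1) * stride - 2 * padding + kernel_size + output_padding
#       = (in - 1) * 2 - (k - 1) + k + 1 = 2 * in            (for odd k)
# The default hidden_size/num_channels/kernel_size values are assumptions for
# the demo, not values read from an AIIA checkpoint.
def _upscaler_shape_check(hidden_size: int = 512, num_channels: int = 3, kernel_size: int = 3):
    layer = nn.ConvTranspose2d(
        in_channels=hidden_size,
        out_channels=num_channels,
        kernel_size=kernel_size,
        stride=2,
        padding=(kernel_size - 1) // 2,
        output_padding=1
    )
    # A small spatial size keeps the check cheap: 8x8 must become 16x16.
    x = torch.randn(1, hidden_size, 8, 8)
    assert layer(x).shape == (1, num_channels, 16, 16)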
def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=10,
                   accumulation_steps=8, use_checkpoint=False):
    # Load and concatenate the datasets, then split 80/20 into train/val.
    loaded_datasets = [aiuNNDataset(d) for d in datasets]
    combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
    train_size = int(0.8 * len(combined_dataset))
    val_size = len(combined_dataset) - train_size
    train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        current_device = torch.cuda.current_device()
        torch.cuda.set_per_process_memory_fraction(0.95, device=current_device)

    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=model.base_model.config.learning_rate)
    # Mixed precision only applies on CUDA; disabling autocast and the scaler
    # on CPU lets the same code path run everywhere.
    use_amp = device.type == 'cuda'
    scaler = GradScaler(enabled=use_amp)
    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()

        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/Training"), start=1):
            low_res = batch['low_res'].to(device, non_blocking=True)
            high_res = batch['high_res'].to(device, non_blocking=True)

            with autocast(device_type=device.type, enabled=use_amp):
                if use_checkpoint:
                    # Gradient checkpointing trades compute for memory by
                    # recomputing the forward pass during backward.
                    features = checkpoint(model, low_res, use_reentrant=False)
                else:
                    features = model(low_res)
                # Scale the loss down so the gradients accumulated over
                # `accumulation_steps` micro-batches match one large batch.
                loss = criterion(features, high_res) / accumulation_steps

            scaler.scale(loss).backward()
            train_loss += loss.item() * accumulation_steps

            if i % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        # Flush any leftover gradients from an incomplete accumulation window.
        if i % accumulation_steps != 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch + 1}, Training Loss: {avg_train_loss:.4f}")

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                low_res = batch['low_res'].to(device, non_blocking=True)
                high_res = batch['high_res'].to(device, non_blocking=True)
                with autocast(device_type=device.type, enabled=use_amp):
                    outputs = model(low_res)
                    loss = criterion(outputs, high_res)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch + 1}, Validation Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Persist both parts: the base model through its own API, and the
            # full module (including the new transposed-convolution head) as a
            # plain state dict, since base_model.save() alone would drop the
            # head's trained weights.
            model.base_model.save("best_model")
            torch.save(model.state_dict(), "best_model_upscaler.pt")

    return model


def main():
    BATCH_SIZE = 1
    ACCUMULATION_STEPS = 8
    USE_CHECKPOINT = False

    # Load the base model using its stored configuration
    # (e.g., hidden_size=512, num_channels=3, etc.).
    base_model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")

    # Wrap the base model with the Upscaler that transforms its last layer.
    model = Upscaler(base_model)
    print("Modified model architecture with transformed final layer:")
    print(base_model.config)

    finetune_model(
        model=model,
        datasets=[
            "/root/training_data/vision-dataset/image_upscaler.parquet",
            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
        ],
        batch_size=BATCH_SIZE,
        epochs=10,
        accumulation_steps=ACCUMULATION_STEPS,
        use_checkpoint=USE_CHECKPOINT
    )


if __name__ == '__main__':
    main()
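
# --- Reloading the fine-tuned upscaler (illustrative sketch) ------------------
# A minimal, commented-out sketch of how the two artifacts written by
# finetune_model above could be restored. It assumes "best_model" round-trips
# through the aiia save/load API used in this file; the state-dict restore is
# plain PyTorch.
#
#     base = AIIABase.load("best_model")
#     upscaler = Upscaler(base)
#     upscaler.load_state_dict(torch.load("best_model_upscaler.pt", map_location="cpu"))
#     upscaler.eval()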