import base64
import io

import numpy as np
import pandas as pd
import torch
from albumentations import (
    Compose, Normalize, RandomBrightnessContrast,
    HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
)
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageFile
from torch import nn
from torch.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint
from torch.utils.data import random_split, DataLoader
from tqdm import tqdm

from aiia import AIIABase

# Allow PIL to decode truncated images instead of raising mid-epoch.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class aiuNNDataset(torch.utils.data.Dataset):
    def __init__(self, parquet_path):
        # Limit to the first 10,000 rows to keep memory usage bounded.
        self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(10000)
        # Register the high-resolution image as an additional target so every
        # random transform is applied with identical parameters to both
        # images; augmenting the pair independently would misalign the
        # low-res input and its high-res ground truth.
        self.augmentation = Compose([
            RandomBrightnessContrast(p=0.5),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            Rotate(limit=45, p=0.5),
            GaussianBlur(blur_limit=(3, 7), p=0.5),
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2()
        ], additional_targets={'high_res': 'image'})

    def __len__(self):
        return len(self.df)

    def load_image(self, image_data):
        image_stream = None
        try:
            # Rows may store raw bytes or base64-encoded strings.
            if isinstance(image_data, str):
                image_data = base64.b64decode(image_data)
            if not isinstance(image_data, bytes):
                raise ValueError("Invalid image data format")
            image_stream = io.BytesIO(image_data)
            image = Image.open(image_stream).convert('RGB')
            return np.array(image)
        except Exception as e:
            raise RuntimeError(f"Error loading image: {e}") from e
        finally:
            if image_stream is not None:
                image_stream.close()

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        low_res_image = self.load_image(row['image_512'])
        high_res_image = self.load_image(row['image_1024'])
        # A single call augments both images with the same sampled parameters.
        augmented = self.augmentation(image=low_res_image, high_res=high_res_image)
        return {
            'low_res': augmented['image'],
            'high_res': augmented['high_res']
        }

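# The paired augmentation above relies on albumentations applying the same
# sampled parameters to every target registered via additional_targets. A
# minimal, illustrative check (not called during training; shapes are
# arbitrary):
def _demo_paired_augmentation():
    aug = Compose([HorizontalFlip(p=1.0)], additional_targets={'high_res': 'image'})
    out = aug(image=np.zeros((512, 512, 3), dtype=np.uint8),
              high_res=np.zeros((1024, 1024, 3), dtype=np.uint8))
    # Both targets were flipped together and keep their own resolutions.
    assert out['image'].shape == (512, 512, 3)
    assert out['high_res'].shape == (1024, 1024, 3)
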
class Upscaler(nn.Module):
    """
    Transforms the base model's final feature map with a single transposed
    convolution. The base model produces a 512x512 feature map; this layer
    upsamples it by a factor of 2 (yielding 1024x1024) and maps the hidden
    features to the output channels.
    """
    def __init__(self, base_model: AIIABase):
        super().__init__()
        self.base_model = base_model
        # A single ConvTranspose2d replaces separate upsampling and
        # convolution layers. With stride=2, padding=1, output_padding=1 the
        # spatial size doubles exactly when the configured kernel_size is 3.
        self.last_transform = nn.ConvTranspose2d(
            in_channels=base_model.config.hidden_size,
            out_channels=base_model.config.num_channels,
            kernel_size=base_model.config.kernel_size,
            stride=2,
            padding=1,
            output_padding=1
        )

    def forward(self, x):
        features = self.base_model(x)
        return self.last_transform(features)

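# Sanity check for the doubling claim above (a minimal sketch; the default
# values are illustrative, not read from a real AIIA config). PyTorch's
# ConvTranspose2d output size, ignoring dilation, is
#   out = (in - 1) * stride - 2 * padding + kernel_size + output_padding
def _upscaled_size(in_size=512, kernel_size=3, stride=2, padding=1, output_padding=1):
    return (in_size - 1) * stride - 2 * padding + kernel_size + output_padding

assert _upscaled_size() == 1024  # (512 - 1) * 2 - 2 + 3 + 1
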
def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
    # Load and concatenate datasets, then hold out 20% for validation.
    loaded_datasets = [aiuNNDataset(d) for d in datasets]
    combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
    train_size = int(0.8 * len(combined_dataset))
    val_size = len(combined_dataset) - train_size
    train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        current_device = torch.cuda.current_device()
        # Cap this process at 95% of GPU memory to leave headroom for the driver.
        torch.cuda.set_per_process_memory_fraction(0.95, device=current_device)

    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=model.base_model.config.learning_rate)
    # Gradient scaling is only meaningful for CUDA mixed precision.
    scaler = GradScaler(enabled=(device.type == 'cuda'))
    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()
        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} Training"), start=1):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            low_res = batch['low_res'].to(device)
            high_res = batch['high_res'].to(device)
            with autocast(device_type=device.type):
                if use_checkpoint:
                    # Trade compute for memory by recomputing activations in
                    # the backward pass; use_reentrant=False removes the need
                    # to mark the input as requiring gradients.
                    features = checkpoint(model, low_res, use_reentrant=False)
                else:
                    features = model(low_res)
                # Divide by accumulation_steps so the summed gradients match
                # the scale of a single large batch.
                loss = criterion(features, high_res) / accumulation_steps
            scaler.scale(loss).backward()
            train_loss += loss.item() * accumulation_steps
            # Step the optimizer only every accumulation_steps batches.
            if i % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        # Flush remaining gradients when the epoch length is not a multiple
        # of accumulation_steps; stepping unconditionally inside the loop
        # would defeat the accumulation.
        if i % accumulation_steps != 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

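        # Bookkeeping note: each term added to train_loss was multiplied back
        # by accumulation_steps, so dividing by len(train_loader) below gives
        # the mean unscaled per-batch loss. With batch_size=1 and
        # accumulation_steps=8, one optimizer step approximates a batch of
        # 1 * 8 = 8 samples.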
        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                low_res = batch['low_res'].to(device)
                high_res = batch['high_res'].to(device)
                with autocast(device_type=device.type):
                    outputs = model(low_res)
                    loss = criterion(outputs, high_res)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Persist the fine-tuned base model and, separately, the full
            # Upscaler state dict; saving only the base model would discard
            # the trained ConvTranspose2d head.
            model.base_model.save("best_model")
            torch.save(model.state_dict(), "best_model_upscaler.pt")
    return model

def main():
    BATCH_SIZE = 1
    ACCUMULATION_STEPS = 8
    USE_CHECKPOINT = False

    # Load the base model with its stored configuration
    # (e.g., hidden_size=512, num_channels=3, etc.).
    base_model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")

    # Wrap the base model with the Upscaler, which replaces its final
    # feature map with a transposed-convolution output layer.
    model = Upscaler(base_model)

    print("Base model configuration:")
    print(base_model.config)

    finetune_model(
        model=model,
        datasets=[
            "/root/training_data/vision-dataset/image_upscaler.parquet",
            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
        ],
        batch_size=BATCH_SIZE,
        epochs=10,
        accumulation_steps=ACCUMULATION_STEPS,
        use_checkpoint=USE_CHECKPOINT
    )


if __name__ == '__main__':
    main()