finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed file with 62 additions and 189 deletions
Showing only changes of commit 0c0372794e - Show all commits


@@ -1,200 +1,73 @@
import torch
import pandas as pd
from albumentations import (
    Compose, Resize, Normalize, RandomBrightnessContrast,
    HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
)
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageFile
import io
import base64
import numpy as np
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
from torch.amp import autocast, GradScaler
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
class aiuNNDataset(torch.utils.data.Dataset):
    def __init__(self, parquet_path):
        self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(10000)
        self.augmentation = Compose([
            RandomBrightnessContrast(p=0.5),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            Rotate(limit=45, p=0.5),
            GaussianBlur(blur_limit=(3, 7), p=0.5),
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2()
        ])

    def __len__(self):
        return len(self.df)

    def load_image(self, image_data):
        try:
            if isinstance(image_data, str):
                image_data = base64.b64decode(image_data)
            if not isinstance(image_data, bytes):
                raise ValueError("Invalid image data format")
            image_stream = io.BytesIO(image_data)
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            image = Image.open(image_stream).convert('RGB')
            image_array = np.array(image)
            return image_array
        except Exception as e:
            raise RuntimeError(f"Error loading image: {str(e)}")
        finally:
            if 'image_stream' in locals():
                image_stream.close()

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        low_res_image = self.load_image(row['image_512'])
        high_res_image = self.load_image(row['image_1024'])
        augmented_low = self.augmentation(image=low_res_image)
        augmented_high = self.augmentation(image=high_res_image)
        return {
            'low_res': augmented_low['image'],
            'high_res': augmented_high['image']
        }


class UpscaleDataset(Dataset):
    def __init__(self, parquet_file, transform=None):
        self.df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(10000)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Decode the byte strings into images
        low_res_image = Image.open(io.BytesIO(row['image_512'])).convert('RGB')
        high_res_image = Image.open(io.BytesIO(row['image_1024'])).convert('RGB')
        if self.transform:
            low_res_image = self.transform(low_res_image)
            high_res_image = self.transform(high_res_image)
        return low_res_image, high_res_image
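# Note: the two dataset classes above differ in what they return. aiuNNDataset runs the
# albumentations pipeline (flips, rotation, blur, then Normalize with mean=std=0.5, which
# maps pixels to roughly [-1, 1]) and returns a dict with 'low_res'/'high_res' tensors,
# while UpscaleDataset only applies the optional torchvision transform and returns a
# (low_res, high_res) tuple.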
class Upscaler(nn.Module):
    """
    Transforms the base model's final feature map using a transposed convolution.
    The base model produces a feature map of size 512x512.
    This layer upsamples by a factor of 2 (yielding 1024x1024) and maps the hidden features
    to the output channels using a single ConvTranspose2d layer.
    """
    def __init__(self, base_model: AIIABase):
        super(Upscaler, self).__init__()
        self.base_model = base_model
        # Instead of adding separate upsampling and convolutional layers, we use a ConvTranspose2d layer.
        self.last_transform = nn.ConvTranspose2d(
            in_channels=base_model.config.hidden_size,
            out_channels=base_model.config.num_channels,
            kernel_size=base_model.config.kernel_size,
            stride=2,
            padding=1,
            output_padding=1
        )

    def forward(self, x):
        features = self.base_model(x)
        return self.last_transform(features)


# Example transform for UpscaleDataset: converting PIL images to tensors
transform = transforms.Compose([
    transforms.ToTensor(),
])
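# Output-size check for Upscaler.last_transform above (explanatory only; assumes the
# configured kernel_size is 3). For ConvTranspose2d,
#   H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding,
# so with stride=2, padding=1, output_padding=1 and a 3x3 kernel:
#   (512 - 1) * 2 - 2 + 3 + 1 = 1024,
# i.e. the 512x512 feature map is doubled to 1024x1024. Other kernel sizes shift the
# output by (kernel_size - 3) pixels per dimension.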
def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
    # Load and concatenate datasets.
    loaded_datasets = [aiuNNDataset(d) for d in datasets]
    combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)

    # 80/20 train/validation split.
    train_size = int(0.8 * len(combined_dataset))
    val_size = len(combined_dataset) - train_size
    train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        current_device = torch.cuda.current_device()
        torch.cuda.set_per_process_memory_fraction(0.95, device=current_device)
    model = model.to(device)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=model.base_model.config.learning_rate)
    scaler = GradScaler()
    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()
        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"), start=1):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            low_res = batch['low_res'].to(device)
            high_res = batch['high_res'].to(device)
            with autocast(device_type="cuda"):
                if use_checkpoint:
                    # Gradient checkpointing needs an input that requires grad.
                    low_res = low_res.requires_grad_()
                    features = checkpoint(lambda x: model(x), low_res)
                else:
                    features = model(low_res)
                # Scale the loss so accumulated gradients match a single full optimizer step.
                loss = criterion(features, high_res) / accumulation_steps
            scaler.scale(loss).backward()
            train_loss += loss.item() * accumulation_steps
            if i % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        # Flush any gradients left over from an incomplete accumulation window.
        if (i % accumulation_steps) != 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                low_res = batch['low_res'].to(device)
                high_res = batch['high_res'].to(device)
                with autocast(device_type="cuda"):
                    outputs = model(low_res)
                    loss = criterion(outputs, high_res)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.base_model.save("best_model")
    return model


def finetune_simple():
    # Simpler fine-tuning path: plain tensors from UpscaleDataset, no AMP, no gradient
    # accumulation and no validation split.
    # Replace with your actual pretrained model path
    pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
    # Load the model using the AIIA.load class method
    model = AIIA.load(pretrained_model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Create the dataset and dataloader
    dataset = UpscaleDataset("/root/training_data/vision-dataset/image_upscaler.parquet", transform=transform)
    data_loader = DataLoader(dataset, batch_size=16, shuffle=True)

    # Define a loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    num_epochs = 10
    model.train()  # Set model in training mode
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for low_res, high_res in data_loader:
            low_res = low_res.to(device)
            high_res = high_res.to(device)
            optimizer.zero_grad()
            outputs = model(low_res)
            loss = criterion(outputs, high_res)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {epoch_loss / len(data_loader)}")

    # Optionally, save the finetuned model to a new directory
    finetuned_model_path = "aiuNN"
    model.save(finetuned_model_path)
def main():
    BATCH_SIZE = 1
    ACCUMULATION_STEPS = 8
    USE_CHECKPOINT = False

    # Load the base model using the provided configuration (e.g., hidden_size=512, num_channels=3, etc.)
    base_model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")

    # Wrap the base model with our modified Upscaler that transforms its last layer.
    model = Upscaler(base_model)
    print("Modified model architecture with transformed final layer:")
    print(base_model.config)

    finetune_model(
        model=model,
        datasets=[
            "/root/training_data/vision-dataset/image_upscaler.parquet",
            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
        ],
        batch_size=BATCH_SIZE,
        epochs=10,
        accumulation_steps=ACCUMULATION_STEPS,
        use_checkpoint=USE_CHECKPOINT
    )
if __name__ == '__main__':
    main()