diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index b0dd634..23211da 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -1,200 +1,73 @@
-import torch
 import pandas as pd
-from albumentations import (
-    Compose, Resize, Normalize, RandomBrightnessContrast,
-    HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
-)
-from albumentations.pytorch import ToTensorV2
-from PIL import Image, ImageFile
 import io
-import base64
-import numpy as np
-from torch import nn
-from torch.utils.data import random_split, DataLoader
-from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
-from torch.amp import autocast, GradScaler
-from tqdm import tqdm
-from torch.utils.checkpoint import checkpoint
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from aiia import AIIA
+
+class UpscaleDataset(Dataset):
+    def __init__(self, parquet_file, transform=None):
+        self.df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(10000)
+        self.transform = transform
 
-class aiuNNDataset(torch.utils.data.Dataset):
-    def __init__(self, parquet_path):
-        self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(10000)
-        self.augmentation = Compose([
-            RandomBrightnessContrast(p=0.5),
-            HorizontalFlip(p=0.5),
-            VerticalFlip(p=0.5),
-            Rotate(limit=45, p=0.5),
-            GaussianBlur(blur_limit=(3, 7), p=0.5),
-            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
-            ToTensorV2()
-        ])
-
     def __len__(self):
         return len(self.df)
-
-    def load_image(self, image_data):
-        try:
-            if isinstance(image_data, str):
-                image_data = base64.b64decode(image_data)
-            if not isinstance(image_data, bytes):
-                raise ValueError("Invalid image data format")
-            image_stream = io.BytesIO(image_data)
-            ImageFile.LOAD_TRUNCATED_IMAGES = True
-            image = Image.open(image_stream).convert('RGB')
-            image_array = np.array(image)
-            return image_array
-        except Exception as e:
-            raise RuntimeError(f"Error loading image: {str(e)}")
-        finally:
-            if 'image_stream' in locals():
-                image_stream.close()
-
+
     def __getitem__(self, idx):
         row = self.df.iloc[idx]
-        low_res_image = self.load_image(row['image_512'])
-        high_res_image = self.load_image(row['image_1024'])
-        augmented_low = self.augmentation(image=low_res_image)
-        augmented_high = self.augmentation(image=high_res_image)
-        return {
-            'low_res': augmented_low['image'],
-            'high_res': augmented_high['image']
-        }
+        # Decode the byte strings into images
+        low_res_bytes = row['image_512']
+        high_res_bytes = row['image_1024']
+        low_res_image = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
+        high_res_image = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
+        if self.transform:
+            low_res_image = self.transform(low_res_image)
+            high_res_image = self.transform(high_res_image)
+        return low_res_image, high_res_image
 
-class Upscaler(nn.Module):
-    """
-    Transforms the base model's final feature map using a transposed convolution.
-    The base model produces a feature map of size 512x512.
-    This layer upsamples by a factor of 2 (yielding 1024x1024) and maps the hidden features
-    to the output channels using a single ConvTranspose2d layer.
-    """
-    def __init__(self, base_model: AIIABase):
-        super(Upscaler, self).__init__()
-        self.base_model = base_model
-        # Instead of adding separate upsampling and convolutional layers, we use a ConvTranspose2d layer.
-        self.last_transform = nn.ConvTranspose2d(
-            in_channels=base_model.config.hidden_size,
-            out_channels=base_model.config.num_channels,
-            kernel_size=base_model.config.kernel_size,
-            stride=2,
-            padding=1,
-            output_padding=1
-        )
-
-    def forward(self, x):
-        features = self.base_model(x)
-        return self.last_transform(features)
+# Example transform: converting PIL images to tensors
+transform = transforms.Compose([
+    transforms.ToTensor(),
+])
+
+
+import torch
+
+# Replace with your actual pretrained model path
+pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
+
+# Load the pretrained model using the AIIA.load class method
+model = AIIA.load(pretrained_model_path)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = model.to(device)
+from torch import nn, optim
+from torch.utils.data import DataLoader
+
+# Create your dataset and dataloader
+dataset = UpscaleDataset("/root/training_data/vision-dataset/image_upscaler.parquet", transform=transform)
+data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
+
+# Define a loss function and optimizer
+criterion = nn.MSELoss()
+optimizer = optim.Adam(model.parameters(), lr=1e-4)
+
+num_epochs = 10
+model.train()  # Set model in training mode
+
+for epoch in range(num_epochs):
+    epoch_loss = 0.0
+    for low_res, high_res in data_loader:
+        low_res = low_res.to(device)
+        high_res = high_res.to(device)
 
-def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
-    # Load and concatenate datasets.
-    loaded_datasets = [aiuNNDataset(d) for d in datasets]
-    combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
-    train_size = int(0.8 * len(combined_dataset))
-    val_size = len(combined_dataset) - train_size
-    train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])
-
-    train_loader = DataLoader(
-        train_dataset,
-        batch_size=batch_size,
-        shuffle=True,
-        num_workers=4,
-        pin_memory=True,
-        persistent_workers=True
-    )
-
-    val_loader = DataLoader(
-        val_dataset,
-        batch_size=batch_size,
-        shuffle=False,
-        num_workers=4,
-        pin_memory=True,
-        persistent_workers=True
-    )
-
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    if device.type == 'cuda':
-        current_device = torch.cuda.current_device()
-        torch.cuda.set_per_process_memory_fraction(0.95, device=current_device)
-
-    model = model.to(device)
-    criterion = nn.MSELoss()
-    optimizer = torch.optim.Adam(model.parameters(), lr=model.base_model.config.learning_rate)
-    scaler = GradScaler()
-    best_val_loss = float('inf')
-
-    for epoch in range(epochs):
-        model.train()
-        train_loss = 0.0
         optimizer.zero_grad()
-        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"), start=1):
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-            low_res = batch['low_res'].to(device)
-            high_res = batch['high_res'].to(device)
-            with autocast(device_type="cuda"):
-                if use_checkpoint:
-                    low_res = batch['low_res'].to(device).requires_grad_()
-                    features = checkpoint(lambda x: model(x), low_res)
-                else:
-                    features = model(low_res)
-                loss = criterion(features, high_res) / accumulation_steps
-            scaler.scale(loss).backward()
-            train_loss += loss.item() * accumulation_steps
-            if i % accumulation_steps == 0:
-                scaler.step(optimizer)
-                scaler.update()
-                optimizer.zero_grad()
-        if (i % accumulation_steps) != 0:
-            scaler.step(optimizer)
-            scaler.update()
-            optimizer.zero_grad()
-
-        avg_train_loss = train_loss / len(train_loader)
-        print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
-
-        model.eval()
-        val_loss = 0.0
-        with torch.no_grad():
-            for batch in tqdm(val_loader, desc="Validation"):
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-                low_res = batch['low_res'].to(device)
-                high_res = batch['high_res'].to(device)
-                with autocast(device_type="cuda"):
-                    outputs = model(low_res)
-                loss = criterion(outputs, high_res)
-                val_loss += loss.item()
-        avg_val_loss = val_loss / len(val_loader)
-        print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
-        if avg_val_loss < best_val_loss:
-            best_val_loss = avg_val_loss
-            model.base_model.save("best_model")
-    return model
-
-def main():
-    BATCH_SIZE = 1
-    ACCUMULATION_STEPS = 8
-    USE_CHECKPOINT = False
-
-    # Load the base model using the provided configuration (e.g., hidden_size=512, num_channels=3, etc.)
-    base_model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
+        outputs = model(low_res)
+        loss = criterion(outputs, high_res)
+        loss.backward()
+        optimizer.step()
+        epoch_loss += loss.item()
+    print(f"Epoch {epoch + 1}, Loss: {epoch_loss / len(data_loader)}")
 
-    # Wrap the base model with our modified Upscaler that transforms its last layer.
-    model = Upscaler(base_model)
-
-    print("Modified model architecture with transformed final layer:")
-    print(base_model.config)
-
-    finetune_model(
-        model=model,
-        datasets=[
-            "/root/training_data/vision-dataset/image_upscaler.parquet",
-            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
-        ],
-        batch_size=BATCH_SIZE,
-        epochs=10,
-        accumulation_steps=ACCUMULATION_STEPS,
-        use_checkpoint=USE_CHECKPOINT
-    )
-
-if __name__ == '__main__':
-    main()
+# Optionally, save the finetuned model to a new directory
+finetuned_model_path = "aiuNN"
+model.save(finetuned_model_path)
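Before training, it is worth confirming that the parquet columns decode the way the new `UpscaleDataset` assumes. A minimal sanity check, assuming `image_512`/`image_1024` hold standard PNG- or JPEG-encoded bytes (as the `Image.open(io.BytesIO(...))` call implies), with the expected sizes inferred from the column names:

```python
# Sanity check (not part of the patch): decode one sample and inspect shapes.
# Reuses UpscaleDataset and transform from the script above.
ds = UpscaleDataset("/root/training_data/vision-dataset/image_upscaler.parquet",
                    transform=transform)
low, high = ds[0]
print(low.shape, high.shape)  # expected: [3, 512, 512] and [3, 1024, 1024]
```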
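One caveat: the patch deletes the `Upscaler` wrapper, whose docstring states that the base model emits a 512x512 feature map. If `AIIA.load` returns that same architecture, `criterion(outputs, high_res)` would hit a shape mismatch against the 1024x1024 targets. A minimal sketch of re-wrapping the loaded model with the removed transposed-convolution head (the `config` attribute names follow the deleted code):

```python
# Sketch: restore the removed ConvTranspose2d upscaling head so a 512x512
# feature map is mapped to 1024x1024 output. Attribute names (hidden_size,
# num_channels, kernel_size) are taken from the deleted Upscaler class.
from torch import nn

class Upscaler(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.last_transform = nn.ConvTranspose2d(
            in_channels=base_model.config.hidden_size,
            out_channels=base_model.config.num_channels,
            kernel_size=base_model.config.kernel_size,
            stride=2,
            padding=1,
            output_padding=1,
        )

    def forward(self, x):
        return self.last_transform(self.base_model(x))

model = Upscaler(model).to(device)  # wrap before constructing the optimizer
```

With `kernel_size=3`, `stride=2`, `padding=1`, and `output_padding=1`, the output side length is (512 - 1) * 2 - 2 + 3 + 1 = 1024, i.e. exactly 2x. Note that the old code saved via `model.base_model.save(...)` after wrapping, since the wrapper itself is a plain `nn.Module` with no `save` method.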
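The old `finetune_model` also trained with mixed precision and gradient accumulation, both dropped by the simplified loop. A sketch of layering them back on, reusing the patch's `model`, `data_loader`, `criterion`, `optimizer`, and `num_epochs`; `accumulation_steps = 8` is carried over from the old default:

```python
# Sketch: mixed precision + gradient accumulation on the simplified loop,
# mirroring what the removed finetune_model did.
from torch.amp import autocast, GradScaler

scaler = GradScaler()
accumulation_steps = 8  # old default

for epoch in range(num_epochs):
    epoch_loss = 0.0
    optimizer.zero_grad()
    for i, (low_res, high_res) in enumerate(data_loader, start=1):
        low_res, high_res = low_res.to(device), high_res.to(device)
        with autocast(device_type=device.type):
            loss = criterion(model(low_res), high_res) / accumulation_steps
        scaler.scale(loss).backward()
        epoch_loss += loss.item() * accumulation_steps
        if i % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    if i % accumulation_steps != 0:  # flush a partial accumulation window
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss / len(data_loader)}")
```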
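Likewise, the 80/20 train/validation split and best-model checkpointing from the old code can be restored around the new loop. A minimal sketch; the `aiuNN-best` output directory is illustrative, and `model.save` is the same method the patch already calls:

```python
# Sketch: 80/20 split with validation-driven checkpointing, as in the old code.
from torch.utils.data import random_split

train_size = int(0.8 * len(dataset))
train_set, val_set = random_split(dataset, [train_size, len(dataset) - train_size])
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)

best_val_loss = float("inf")
for epoch in range(num_epochs):
    model.train()
    # ... run the training pass over train_loader as in the patch ...

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for low_res, high_res in val_loader:
            outputs = model(low_res.to(device))
            val_loss += criterion(outputs, high_res.to(device)).item()
    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch + 1}, Validation Loss: {avg_val_loss:.4f}")
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        model.save("aiuNN-best")  # illustrative output directory
```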