diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index e155e5e..f40296e 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -8,23 +8,22 @@ from albumentations.pytorch import ToTensorV2 from PIL import Image, ImageFile import io import base64 +import numpy as np from torch import nn -# Import the model and config from your existing code +from torch.utils.data import random_split from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive class aiuNNDataset(torch.utils.data.Dataset): def __init__(self, parquet_path): - # Read the Parquet file - self.df = pd.read_parquet(parquet_path).head(1250) + self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']) - # Data augmentation pipeline without Resize as it's redundant self.augmentation = Compose([ - RandomBrightnessContrast(), + RandomBrightnessContrast(p=0.5), HorizontalFlip(p=0.5), VerticalFlip(p=0.5), - Rotate(degrees=45), - GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), - Normalize(mean=[0.5], std=[0.5]), + Rotate(limit=45, p=0.5), + GaussianBlur(blur_limit=(3, 7), p=0.5), + Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2() ]) @@ -33,92 +32,66 @@ class aiuNNDataset(torch.utils.data.Dataset): def load_image(self, image_data): try: - # Handle both bytes and base64 encoded strings if isinstance(image_data, str): - # Decode base64 string to bytes image_data = base64.b64decode(image_data) - # Verify data is valid before creating BytesIO if not isinstance(image_data, bytes): raise ValueError("Invalid image data format") - # Create image stream image_stream = io.BytesIO(image_data) - - # Enable loading of truncated images ImageFile.LOAD_TRUNCATED_IMAGES = True - # Load and convert image to RGB image = Image.open(image_stream).convert('RGB') + image_array = np.array(image) - # Create fresh copy for verify() since it modifies the image object - image_verify = image.copy() - - # Verify image is valid - try: - image_verify.verify() - except Exception as e: - raise ValueError(f"Image verification failed: {str(e)}") - finally: - image_verify.close() - - return image - + return image_array except Exception as e: raise RuntimeError(f"Error loading image: {str(e)}") - finally: - # Ensure stream is closed if 'image_stream' in locals(): image_stream.close() def __getitem__(self, idx): row = self.df.iloc[idx] - # Load images using the new method low_res_image = self.load_image(row['image_512']) high_res_image = self.load_image(row['image_1024']) - # Apply augmentation and normalization augmented_low = self.augmentation(image=low_res_image) - low_res = augmented_low['image'] - augmented_high = self.augmentation(image=high_res_image) - high_res = augmented_high['image'] - return { - 'low_res': low_res, - 'high_res': high_res + 'low_res': augmented_low['image'], + 'high_res': augmented_high['image'] } -from torch.utils.data.dataset import ConcatDataset - -def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10): - # Load all datasets and concatenate them +def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10): loaded_datasets = [aiuNNDataset(d) for d in datasets] - combined_dataset = ConcatDataset(loaded_datasets) + combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets) - # Split into training and validation sets - train_dataset, val_dataset = combined_dataset.train_val_split() + train_size = int(0.8 * len(combined_dataset)) + val_size = len(combined_dataset) - train_size + train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size]) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, shuffle=True, - num_workers=4 + num_workers=4, + pin_memory=True, + persistent_workers=True ) val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=batch_size, shuffle=False, - num_workers=4 + num_workers=4, + pin_memory=True, + persistent_workers=True ) - # Set device - device = 'cuda' if torch.cuda.is_available() else 'cpu' - model.to(device) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = model.to(device) - # Define loss function and optimizer criterion = nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate) @@ -128,73 +101,62 @@ def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10): for epoch in range(epochs): model.train() - train_loss = 0.0 - for batch_idx, batch in enumerate(tqdm(train_loader)): - # Your training code here + for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"): + if torch.cuda.is_available(): + torch.cuda.empty_cache() low_res = batch['low_res'].to(device) high_res = batch['high_res'].to(device) - # Forward pass - outputs = model(low_res) - - # Calculate loss - loss = criterion(outputs, high_res.permute(0, 3, 1, 2)) # Adjust for channel dimensions - - # Backward pass and optimize optimizer.zero_grad() + outputs = model(low_res) + loss = criterion(outputs, high_res) + loss.backward() optimizer.step() - train_loss += loss.item() - - avg_train_loss = train_loss / len(train_loader) + avg_train_loss = train_loss / len(train_loader) print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}") - # Validation model.eval() val_loss = 0.0 with torch.no_grad(): for batch in tqdm(val_loader, desc="Validation"): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + low_res = batch['low_res'].to(device) high_res = batch['high_res'].to(device) outputs = model(low_res) - loss = criterion(outputs, high_res.permute(0, 3, 1, 2)) - + loss = criterion(outputs, high_res) val_loss += loss.item() avg_val_loss = val_loss / len(val_loader) - print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}") - - # Save best model if avg_val_loss < best_val_loss: best_val_loss = avg_val_loss - model.save("best_model") + torch.save(model.state_dict(), "best_model.pth") return model def main(): - # Paths to your data - train_parquet_path = "/root/training_data/vision-dataset/image_upscaler.parquet" - val_parquet_path = "/root/training_data/vision-dataset/image_vec_upscaler.parquet" - - # Load pretrained model + BATCH_SIZE = 1 model = AIIABase.load("/root/vision/AIIA/AIIA-base-512") - # Add final upsampling layer if needed (depending on your specific architecture) if hasattr(model, 'chunked_'): model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear')) - # Fine-tune finetune_model( - model, - train_parquet_path, - val_parquet_path + model=model, + datasets=[ + "/root/training_data/vision-dataset/image_upscaler.parquet", + "/root/training_data/vision-dataset/image_vec_upscaler.parquet" + ], + batch_size=BATCH_SIZE ) if __name__ == '__main__':