import base64
import io

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader, random_split
from albumentations import (
    Compose, Normalize, RandomBrightnessContrast,
    HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
)
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageFile

# Import the model and config from your existing code
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive


class aiuNNDataset(torch.utils.data.Dataset):
    def __init__(self, parquet_path):
        # Read the Parquet file
        self.df = pd.read_parquet(parquet_path)

        # Data augmentation pipeline without Resize, as the parquet already stores fixed sizes.
        # additional_targets makes Albumentations apply the same random parameters to the
        # high-resolution image, so each low-res/high-res pair stays geometrically aligned.
        self.augmentation = Compose(
            [
                RandomBrightnessContrast(),
                HorizontalFlip(p=0.5),
                VerticalFlip(p=0.5),
                Rotate(limit=45),
                GaussianBlur(blur_limit=(3, 3), sigma_limit=(0.1, 2.0)),
                Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                ToTensorV2(),
            ],
            additional_targets={'high_res': 'image'},
        )

    def __len__(self):
        return len(self.df)

    def load_image(self, image_data):
        image_stream = None
        try:
            # Handle both raw bytes and base64-encoded strings
            if isinstance(image_data, str):
                image_data = base64.b64decode(image_data)

            # Verify data is valid before creating BytesIO
            if not isinstance(image_data, bytes):
                raise ValueError("Invalid image data format")

            # Create image stream
            image_stream = io.BytesIO(image_data)

            # Enable loading of truncated images
            ImageFile.LOAD_TRUNCATED_IMAGES = True

            # Load and convert image to RGB
            image = Image.open(image_stream).convert('RGB')

            # verify() invalidates the image object it is called on, so run it on a copy
            image_verify = image.copy()
            try:
                image_verify.verify()
            except Exception as e:
                raise ValueError(f"Image verification failed: {str(e)}")
            finally:
                image_verify.close()

            return image
        except Exception as e:
            raise RuntimeError(f"Error loading image: {str(e)}")
        finally:
            # Ensure the stream is closed
            if image_stream is not None:
                image_stream.close()

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Load both resolutions; Albumentations expects numpy arrays, not PIL images
        low_res_image = np.array(self.load_image(row['image_512']))
        high_res_image = np.array(self.load_image(row['image_1024']))

        # Apply augmentation and normalization with shared random parameters
        augmented = self.augmentation(image=low_res_image, high_res=high_res_image)

        return {
            'low_res': augmented['image'],
            'high_res': augmented['high_res'],
        }


def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
    # Load all datasets and concatenate them
    loaded_datasets = [aiuNNDataset(d) for d in datasets]
    combined_dataset = ConcatDataset(loaded_datasets)

    # Split into training and validation sets (ConcatDataset has no built-in split)
    train_size = int(0.8 * len(combined_dataset))
    val_size = len(combined_dataset) - train_size
    train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )

    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    # Define loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)

    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        for batch in train_loader:
            low_res = batch['low_res'].to(device)
            high_res = batch['high_res'].to(device)

            # Forward pass
            outputs = model(low_res)

            # ToTensorV2 already returns CHW tensors, so no permute is needed here
            loss = criterion(outputs, high_res)

            # Backward pass and optimization step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                low_res = batch['low_res'].to(device)
                high_res = batch['high_res'].to(device)

                outputs = model(low_res)
                loss = criterion(outputs, high_res)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save("best_model")

    return model


def main():
    # Paths to your data
    train_parquet_path = "/root/training_data/vision-dataset/image_upscaler.parquet"
    val_parquet_path = "/root/training_data/vision-dataset/image_vec_upscaler.parquet"

    # Load pretrained model
    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")

    # Add final upsampling layer if needed (depending on your specific architecture)
    if hasattr(model, 'chunked_'):
        model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))

    # Fine-tune; finetune_model expects a list of parquet paths
    finetune_model(
        model,
        [train_parquet_path, val_parquet_path]
    )


if __name__ == '__main__':
    main()
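

# --- Optional sanity check (sketch) ---
# A quick way to confirm the dataset returns aligned, correctly shaped CHW tensors before
# committing to a full fine-tuning run. The parquet path argument is a placeholder; point it
# at any file with the 'image_512' / 'image_1024' columns used above. This helper is an
# illustrative addition, not part of the training pipeline itself.
def _sanity_check_dataset(parquet_path: str) -> None:
    ds = aiuNNDataset(parquet_path)
    sample = ds[0]
    # Expected shapes: (3, 512, 512) for low_res and (3, 1024, 1024) for high_res
    print("low_res shape:", tuple(sample['low_res'].shape))
    print("high_res shape:", tuple(sample['high_res'].shape))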