import base64
import csv
import io

import pandas as pd
import torch
from PIL import Image, ImageFile
from torch import nn, optim
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm

from aiia import AIIABase


class UpscaleDataset(Dataset):
    def __init__(self, parquet_files: list, transform=None):
        combined_df = pd.DataFrame()
        for parquet_file in parquet_files:
            # Load only the two image columns and cap each file at 5000 rows
            # to keep memory usage manageable.
            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(5000)
            combined_df = pd.concat([combined_df, df], ignore_index=True)

        # Validate the data format of every row up front.
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.failed_indices = set()

    def _validate_row(self, row):
        """Ensure both image columns contain raw bytes or base64-encoded strings."""
        for col in ['image_512', 'image_1024']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        """Decode either a base64-encoded string or raw bytes into image bytes."""
        try:
            if isinstance(data, str):
                # Handle base64-encoded strings.
                return base64.b64decode(data)
            if isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            raise RuntimeError(f"Decoding failed: {str(e)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Skip indices that already failed and fall back to the next sample.
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]
        try:
            row = self.df.iloc[idx]

            # Decode both images.
            low_res_bytes = self._decode_image(row['image_512'])
            high_res_bytes = self._decode_image(row['image_1024'])

            # Load the images, tolerating truncated files.
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
            high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')

            # Validate the image sizes.
            if low_res.size != (512, 512) or high_res.size != (1024, 1024):
                raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")

            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)
            return low_res, high_res
        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            # Return the next valid sample instead of aborting the epoch.
            return self[(idx + 1) % len(self)]


# Example transform: convert PIL images to tensors.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Replace with your actual pretrained model path.
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"

# Load the pretrained model using the AIIABase.load class method.
model = AIIABase.load(pretrained_model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Create the dataset and dataloader.
dataset = UpscaleDataset(
    [
        "/root/training_data/vision-dataset/image_upscaler.parquet",
        "/root/training_data/vision-dataset/image_vec_upscaler.parquet",
    ],
    transform=transform,
)
data_loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Define the loss function and optimizer.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 10
model.train()  # Set the model to training mode.

csv_file = 'losses.csv'

# Create or open the CSV file and write the header only if the file is empty.
with open(csv_file, mode='a', newline='') as file:
    writer = csv.writer(file)
    if file.tell() == 0:
        writer.writerow(['Epoch', 'Train Loss'])

# Create a gradient scaler for automatic mixed precision (AMP) training.
scaler = GradScaler()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    data_loader_with_progress = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
    for low_res, high_res in data_loader_with_progress:
        low_res = low_res.to(device, non_blocking=True)
        high_res = high_res.to(device, non_blocking=True)
        optimizer.zero_grad()

        # Run the forward pass under the automatic mixed precision context.
        # autocast expects the device type as a string ("cuda" or "cpu"),
        # not a torch.device object.
        with autocast(device_type=device.type):
            outputs = model(low_res)
            loss = criterion(outputs, high_res)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()

    # epoch_loss is the sum of the batch losses for this epoch.
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")

    # Append the training loss to the CSV file.
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([epoch + 1, epoch_loss])

# Optionally, save the finetuned model to a new directory.
finetuned_model_path = "aiuNN"
model.save(finetuned_model_path)
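
# Optional sanity check after training: a minimal sketch, not part of the run
# above. It assumes the finetuned checkpoint saved with model.save() can be
# reloaded via AIIABase.load, and that the model maps a 512x512 input tensor to
# a 1024x1024 output, as the training loop implies. "example_512.png" and the
# output filename are hypothetical placeholders.
import os

sample_image_path = "example_512.png"
if os.path.exists(sample_image_path):
    upscaler = AIIABase.load(finetuned_model_path).to(device)
    upscaler.eval()
    with torch.no_grad():
        low = transform(Image.open(sample_image_path).convert('RGB')).unsqueeze(0).to(device)
        upscaled = upscaler(low).squeeze(0).clamp(0.0, 1.0).cpu()
    transforms.ToPILImage()(upscaled).save("example_1024.png")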