import base64
import csv
import io
from typing import Optional

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image, ImageFile
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

from aiia import AIIABase
from aiunn.upsampler import Upsampler


class EarlyStopping:
    """Signal that training should stop once the epoch loss stops improving.

    Call the instance with the loss of each finished epoch. A loss counts as
    an improvement only if it is lower than the best loss seen so far by more
    than ``min_delta``. After ``patience`` consecutive epochs without such an
    improvement the instance reports ``True``.
    """

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience    # Epochs without improvement tolerated before stopping.
        self.min_delta = min_delta  # Minimum decrease in loss that counts as an improvement.
        self.best_loss = float('inf')
        self.counter = 0            # Consecutive epochs without improvement.
        self.early_stop = False

    def __call__(self, epoch_loss):
        """Record ``epoch_loss`` and return whether training should stop."""
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.counter = 0
        else:
            # No significant improvement this epoch.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop


class UpscaleDataset(Dataset):
    """Pairs of 512x512 / 1024x1024 RGB images read from parquet files.

    Each parquet file must provide the columns ``image_512`` and
    ``image_1024``. A cell is either raw encoded image bytes or a base64
    string of such bytes. Rows that cannot be decoded, or whose images do not
    have the expected size, are remembered in ``failed_indices`` and replaced
    by the next loadable row.
    """

    IMAGE_COLUMNS = ('image_512', 'image_1024')
    LOW_RES_SIZE = (512, 512)
    HIGH_RES_SIZE = (1024, 1024)

    def __init__(self, parquet_files: list, transform=None,
                 max_rows_per_file: Optional[int] = 1250):
        """Load and validate the image columns of every parquet file.

        Args:
            parquet_files: Paths of the parquet files to combine.
            transform: Optional callable applied to both images of a pair.
            max_rows_per_file: Keep only the first N rows of each file;
                ``None`` keeps every row.

        Raises:
            ValueError: If any cell is neither ``bytes`` nor ``str``.
        """
        frames = []
        for parquet_file in parquet_files:
            df = pd.read_parquet(parquet_file, columns=list(self.IMAGE_COLUMNS))
            if max_rows_per_file is not None:
                df = df.head(max_rows_per_file)
            frames.append(df)

        # Concatenate once instead of growing a DataFrame inside the loop,
        # which would copy all previously loaded rows on every iteration.
        if frames:
            combined_df = pd.concat(frames, ignore_index=True)
        else:
            combined_df = pd.DataFrame(columns=list(self.IMAGE_COLUMNS))

        # Validate that each row has proper image formats.
        for _, row in combined_df.iterrows():
            self._validate_row(row)

        self.df = combined_df
        self.transform = transform
        self.failed_indices = set()

    def _validate_row(self, row):
        """Return ``row`` unchanged, raising ``ValueError`` on a bad cell type."""
        for col in self.IMAGE_COLUMNS:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        """Return the raw image bytes for a cell value.

        ``bytes`` are returned as they are; ``str`` is decoded as base64.

        Raises:
            RuntimeError: If the value has an unsupported type or is not
                valid base64.
        """
        if isinstance(data, bytes):
            return data
        if isinstance(data, str):
            try:
                return base64.b64decode(data)
            except ValueError as e:
                # binascii.Error is a ValueError subclass; keep the cause.
                raise RuntimeError(f"Decoding failed: {str(e)}") from e
        raise RuntimeError(f"Decoding failed: Unsupported data type: {type(data)}")

    def __len__(self):
        return len(self.df)

    def _load_pair(self, idx):
        """Decode, size-check and transform the image pair stored at ``idx``."""
        row = self.df.iloc[idx]
        low_res_bytes = self._decode_image(row['image_512'])
        high_res_bytes = self._decode_image(row['image_1024'])

        # Tolerate truncated image files instead of failing while loading.
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
        high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')

        # Validate expected sizes
        if low_res.size != self.LOW_RES_SIZE or high_res.size != self.HIGH_RES_SIZE:
            raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")

        if self.transform:
            low_res = self.transform(low_res)
            high_res = self.transform(high_res)
        return low_res, high_res

    def __getitem__(self, idx):
        """Return the ``(low_res, high_res)`` pair at ``idx``.

        If that row is unusable, the following rows are tried in order,
        wrapping around at the end of the dataset. Every row is tried at most
        once per call, so a dataset without any loadable row produces an
        error instead of unbounded recursion.

        Raises:
            IndexError: If the dataset is empty.
            RuntimeError: If no row of the dataset can be loaded.
        """
        total = len(self)
        if total == 0:
            raise IndexError("UpscaleDataset is empty")

        for offset in range(total):
            candidate = (idx + offset) % total
            # Skip indices that have previously failed.
            if candidate in self.failed_indices:
                continue
            try:
                return self._load_pair(candidate)
            except Exception as e:
                # Best effort by design: one bad sample must not abort training.
                print(f"\nError at index {candidate}: {str(e)}")
                self.failed_indices.add(candidate)

        raise RuntimeError(
            f"No loadable samples in dataset: all {total} rows failed to load"
        )


# Define any transformations you require (e.g., converting PIL images to tensors)
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load the base AIIABase model and wrap it with the Upsampler.
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
base_model = AIIABase.load(pretrained_model_path)
model = Upsampler(base_model)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Create the dataset and dataloader.
dataset = UpscaleDataset(
    [
        "/root/training_data/vision-dataset/image_upscaler.parquet",
        "/root/training_data/vision-dataset/image_vec_upscaler.parquet",
    ],
    transform=transform,
)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

# Define loss function and optimizer.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 10
model.train()

# Prepare a CSV file for logging training loss. The file is opened in append
# mode, so the header is written only when the file is still empty.
csv_file = 'losses.csv'
with open(csv_file, mode='a', newline='') as file:
    writer = csv.writer(file)
    if file.tell() == 0:
        writer.writerow(['Epoch', 'Train Loss'])

# Initialize automatic mixed precision scaler and EarlyStopping.
# Gradient scaling is only needed on CUDA; disabling it explicitly elsewhere
# turns the scaler into a pass-through instead of relying on its own fallback.
scaler = GradScaler(enabled=(device.type == "cuda"))
early_stopping = EarlyStopping(patience=3, min_delta=0.001)

# Training loop with early stopping.
for epoch in range(num_epochs):
    # Sum (not mean) of the per-batch losses of this epoch; this is the value
    # that is logged and handed to early stopping.
    epoch_loss = 0.0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
    print(f"Epoch: {epoch}")

    for low_res, high_res in progress_bar:
        low_res = low_res.to(device, non_blocking=True)
        high_res = high_res.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Use automatic mixed precision to speed up training on supported hardware.
        with autocast(device_type=device.type):
            outputs = model(low_res)
            loss = criterion(outputs, high_res)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once per step; every .item() call copies the value
        # to the host.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        progress_bar.set_postfix({'loss': batch_loss})

    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")

    # Record the loss in the CSV log.
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([epoch + 1, epoch_loss])

    # Check early stopping criteria.
    if early_stopping(epoch_loss):
        print(f"Early stopping triggered at epoch {epoch + 1} with loss {epoch_loss}")
        break

# Optionally, save the finetuned model using your library's save method.
finetuned_model_path = "aiuNN"
model.save(finetuned_model_path)