import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import io
import csv
import base64
from PIL import Image, ImageFile
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc

from aiia import AIIABase
from upsampler import Upsampler


# Simple early-stopping helper that monitors the epoch loss.
class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience    # Number of epochs with no significant improvement before stopping.
        self.min_delta = min_delta  # Minimum decrease in loss required to count as an improvement.
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop


# UpscaleDataset loads 512x512 / 1024x1024 image pairs from parquet files and decodes them on the fly.
class UpscaleDataset(Dataset):
    def __init__(self, parquet_files: list, transform=None):
        combined_df = pd.DataFrame()
        for parquet_file in parquet_files:
            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(500)
            combined_df = pd.concat([combined_df, df], ignore_index=True)

        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.failed_indices = set()

    def _validate_row(self, row):
        for col in ['image_512', 'image_1024']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        try:
            if isinstance(data, str):
                return base64.b64decode(data)
            elif isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            raise RuntimeError(f"Decoding failed: {str(e)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Skip samples that already failed once and fall back to the next index.
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]
        try:
            row = self.df.iloc[idx]
            low_res_bytes = self._decode_image(row['image_512'])
            high_res_bytes = self._decode_image(row['image_1024'])

            ImageFile.LOAD_TRUNCATED_IMAGES = True
            low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
            high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')

            if low_res.size != (512, 512) or high_res.size != (1024, 1024):
                raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")

            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)
            return low_res, high_res
        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            return self[(idx + 1) % len(self)]


# Define any transformations you require.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load the base AIIABase model and wrap it with the Upsampler.
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
base_model = AIIABase.load(pretrained_model_path)
model = Upsampler(base_model)

device = torch.device("cpu")  # torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the device using channels_last memory format.
model = model.to(device, memory_format=torch.channels_last)

# Optional: flag to enable gradient checkpointing (trades extra compute for lower memory use).
use_checkpointing = True

# Create the dataset and dataloader.
dataset = UpscaleDataset([
    "/root/training_data/vision-dataset/image_upscaler.parquet",
    "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
], transform=transform)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)  # Consider adjusting num_workers if needed.

# Define loss function and optimizer.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 10
model.train()

# Prepare a CSV file for logging the training loss; write the header only if the file is empty.
csv_file = 'losses.csv'
with open(csv_file, mode='a', newline='') as file:
    writer = csv.writer(file)
    if file.tell() == 0:
        writer.writerow(['Epoch', 'Train Loss'])

# Loss scaling is only useful with CUDA autocast; disable it on CPU so the script still runs.
scaler = GradScaler(enabled=(device.type == "cuda"))
early_stopping = EarlyStopping(patience=3, min_delta=0.001)

# Training loop with early stopping.
for epoch in range(num_epochs):
    epoch_loss = 0.0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
    print(f"Epoch: {epoch + 1}")

    for low_res, high_res in progress_bar:
        # Move data to the device, using channels_last format for the input where possible.
        low_res = low_res.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        high_res = high_res.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast(device_type=device.type):
            if use_checkpointing:
                # Ensure the input tensor requires gradient so that checkpointing records the computation graph.
                low_res.requires_grad_()
                outputs = checkpoint(model, low_res, use_reentrant=False)
            else:
                outputs = model(low_res)
            loss = criterion(outputs, high_res)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})

        # Optionally delete variables to free memory.
        del low_res, high_res, outputs, loss

    # Perform garbage collection and clear the GPU cache after each epoch.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")

    # Record the loss in the CSV log.
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([epoch + 1, epoch_loss])

    if early_stopping(epoch_loss):
        print(f"Early stopping triggered at epoch {epoch + 1} with loss {epoch_loss}")
        break

# Optionally save the fine-tuned model.
finetuned_model_path = "aiuNN"
model.save(finetuned_model_path)