diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index 84f6b69..817e080 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -12,13 +12,12 @@ import numpy as np from torch import nn from torch.utils.data import random_split, DataLoader from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive -from torch.amp import autocast, GradScaler +from torch.cuda.amp import autocast, GradScaler from tqdm import tqdm class aiuNNDataset(torch.utils.data.Dataset): def __init__(self, parquet_path): self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2000) - self.augmentation = Compose([ RandomBrightnessContrast(p=0.5), HorizontalFlip(p=0.5), @@ -28,37 +27,31 @@ class aiuNNDataset(torch.utils.data.Dataset): Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2() ]) - + def __len__(self): return len(self.df) - + def load_image(self, image_data): try: if isinstance(image_data, str): image_data = base64.b64decode(image_data) - if not isinstance(image_data, bytes): raise ValueError("Invalid image data format") - image_stream = io.BytesIO(image_data) ImageFile.LOAD_TRUNCATED_IMAGES = True - image = Image.open(image_stream).convert('RGB') image_array = np.array(image) - return image_array except Exception as e: raise RuntimeError(f"Error loading image: {str(e)}") finally: if 'image_stream' in locals(): image_stream.close() - + def __getitem__(self, idx): row = self.df.iloc[idx] - low_res_image = self.load_image(row['image_512']) high_res_image = self.load_image(row['image_1024']) - augmented_low = self.augmentation(image=low_res_image) augmented_high = self.augmentation(image=high_res_image) return { @@ -66,10 +59,10 @@ class aiuNNDataset(torch.utils.data.Dataset): 'high_res': augmented_high['image'] } -def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10): +def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False): + # Load and concatenate datasets. loaded_datasets = [aiuNNDataset(d) for d in datasets] combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets) - train_size = int(0.8 * len(combined_dataset)) val_size = len(combined_dataset) - train_size train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size]) @@ -93,38 +86,57 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10): ) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - # Limit VRAM usage to 95% of available memory (reducing risk of overflow) if device.type == 'cuda': torch.cuda.set_per_process_memory_fraction(0.95, device=device) - model = model.to(device) criterion = nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate) - scaler = GradScaler() best_val_loss = float('inf') + # Import checkpoint if gradient checkpointing is desired + from torch.utils.checkpoint import checkpoint + for epoch in range(epochs): model.train() train_loss = 0.0 - for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"): + optimizer.zero_grad() + # Gradient accumulation over several steps (effective batch size = accumulation_steps) + for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"), start=1): if torch.cuda.is_available(): torch.cuda.empty_cache() low_res = batch['low_res'].to(device) high_res = batch['high_res'].to(device) - optimizer.zero_grad() + with autocast(): - outputs = model(low_res) - loss = criterion(outputs, high_res) + if use_checkpoint: + # Wrap the forward pass with checkpointing to save memory. + outputs = checkpoint(lambda x: model(x), low_res) + else: + outputs = model(low_res) + # Divide loss to average over accumulation steps. + loss = criterion(outputs, high_res) / accumulation_steps + scaler.scale(loss).backward() + train_loss += loss.item() * accumulation_steps # recover actual loss value + + # Update the optimizer every accumulation_steps iterations. + if i % accumulation_steps == 0: + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + # In case remaining gradients are present from an incomplete accumulation round. + if (i % accumulation_steps) != 0: scaler.step(optimizer) scaler.update() - train_loss += loss.item() + optimizer.zero_grad() avg_train_loss = train_loss / len(train_loader) print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}") + # Validation loop (without accumulation, using standard precision) model.eval() val_loss = 0.0 with torch.no_grad(): @@ -139,15 +151,19 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10): val_loss += loss.item() avg_val_loss = val_loss / len(val_loader) print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}") + if avg_val_loss < best_val_loss: best_val_loss = avg_val_loss model.save("best_model") + return model def main(): - BATCH_SIZE = 2 - model = AIIABase.load("/root/vision/AIIA/AIIA-base-512") + BATCH_SIZE = 1 # Use a batch size of 1. + ACCUMULATION_STEPS = 8 # Accumulate gradients over 8 iterations for an effective batch size of 8. + USE_CHECKPOINT = False # Set to True to enable gradient checkpointing instead. + model = AIIABase.load("/root/vision/AIIA/AIIA-base-512") if hasattr(model, 'chunked_'): model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear')) @@ -157,7 +173,10 @@ def main(): "/root/training_data/vision-dataset/image_upscaler.parquet", "/root/training_data/vision-dataset/image_vec_upscaler.parquet" ], - batch_size=BATCH_SIZE + batch_size=BATCH_SIZE, + epochs=10, + accumulation_steps=ACCUMULATION_STEPS, + use_checkpoint=USE_CHECKPOINT ) if __name__ == '__main__':