diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 817e080..0141a6d 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -14,6 +14,7 @@ from torch.utils.data import random_split, DataLoader
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
 from torch.cuda.amp import autocast, GradScaler
 from tqdm import tqdm
+from torch.utils.checkpoint import checkpoint
 
 class aiuNNDataset(torch.utils.data.Dataset):
     def __init__(self, parquet_path):
@@ -86,57 +87,48 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
     )
 
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # Fix: pass the current device index (an integer) rather than a torch.device without an index.
     if device.type == 'cuda':
-        torch.cuda.set_per_process_memory_fraction(0.95, device=device)
-        model = model.to(device)
+        current_device = torch.cuda.current_device()
+        torch.cuda.set_per_process_memory_fraction(0.95, device=current_device)
+    model = model.to(device)
 
     criterion = nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
     scaler = GradScaler()
     best_val_loss = float('inf')
 
-    # Import checkpoint if gradient checkpointing is desired
-    from torch.utils.checkpoint import checkpoint
-
     for epoch in range(epochs):
         model.train()
         train_loss = 0.0
         optimizer.zero_grad()
-        # Gradient accumulation over several steps (effective batch size = accumulation_steps)
         for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"), start=1):
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
             low_res = batch['low_res'].to(device)
             high_res = batch['high_res'].to(device)
-
             with autocast():
                 if use_checkpoint:
-                    # Wrap the forward pass with checkpointing to save memory.
+                    # Checkpoint the forward pass: activations are recomputed during backward to save memory.
                     outputs = checkpoint(lambda x: model(x), low_res)
                 else:
                     outputs = model(low_res)
-                # Divide loss to average over accumulation steps.
                 loss = criterion(outputs, high_res) / accumulation_steps
-
             scaler.scale(loss).backward()
-            train_loss += loss.item() * accumulation_steps  # recover actual loss value
-
-            # Update the optimizer every accumulation_steps iterations.
+            train_loss += loss.item() * accumulation_steps
             if i % accumulation_steps == 0:
                 scaler.step(optimizer)
                 scaler.update()
                 optimizer.zero_grad()
-
-        # In case remaining gradients are present from an incomplete accumulation round.
+        # Flush any gradients left over from an incomplete accumulation round.
         if (i % accumulation_steps) != 0:
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
-
+
         avg_train_loss = train_loss / len(train_loader)
         print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
 
-        # Validation loop (without accumulation, using standard precision)
         model.eval()
         val_loss = 0.0
         with torch.no_grad():
@@ -151,22 +143,20 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
                 val_loss += loss.item()
         avg_val_loss = val_loss / len(val_loader)
         print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
-
        if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
             model.save("best_model")
-
     return model
 
 def main():
-    BATCH_SIZE = 1  # Use a batch size of 1.
-    ACCUMULATION_STEPS = 8  # Accumulate gradients over 8 iterations for an effective batch size of 8.
-    USE_CHECKPOINT = False  # Set to True to enable gradient checkpointing instead.
+    BATCH_SIZE = 2  # Use a batch size of 2.
+    ACCUMULATION_STEPS = 8  # Accumulate gradients to simulate a larger batch.
+    USE_CHECKPOINT = True  # Enable gradient checkpointing to trade extra compute for lower memory use.
 
     model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
     if hasattr(model, 'chunked_'):
         model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
-
+
     finetune_model(
         model=model,
         datasets=[
@@ -178,6 +168,6 @@ def main():
         accumulation_steps=ACCUMULATION_STEPS,
         use_checkpoint=USE_CHECKPOINT
     )
-
+
 if __name__ == '__main__':
     main()
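
Note on the pattern above: the following is a minimal, self-contained sketch of the same training-step logic (mixed precision with autocast/GradScaler, gradient accumulation for an effective batch of batch_size * accumulation_steps, and optional gradient checkpointing), runnable on CPU or GPU. Dividing the loss by accumulation_steps keeps the accumulated gradient equal to the average over the window, so the learning rate needs no rescaling. TinyUpscaler, train_steps, and the random tensors are placeholders invented for this example and are not part of the aiunn/aiia code; the sketch also passes use_reentrant=False to checkpoint (PyTorch >= 1.11), which the patch itself does not, so that parameter gradients still flow even though the input tensor does not require grad.

# Minimal sketch of the accumulate-then-step pattern used in the patch:
# mixed precision via autocast/GradScaler, gradient accumulation, and
# optional gradient checkpointing. TinyUpscaler and the random tensors
# are placeholders for illustration only; they are not part of the aiunn codebase.
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint


class TinyUpscaler(nn.Module):
    """Stand-in model: one conv followed by 2x bilinear upsampling."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        return self.up(self.conv(x))


def train_steps(model, batches, accumulation_steps=4, use_checkpoint=False, device='cpu'):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    use_amp = (device == 'cuda')
    scaler = GradScaler(enabled=use_amp)  # scaling becomes a no-op on CPU

    model.train()
    optimizer.zero_grad()
    i = 0
    for i, (low_res, high_res) in enumerate(batches, start=1):
        low_res, high_res = low_res.to(device), high_res.to(device)
        with autocast(enabled=use_amp):
            if use_checkpoint:
                # Recompute activations during backward instead of storing them.
                # use_reentrant=False (PyTorch >= 1.11) lets parameter gradients
                # flow even though low_res itself does not require grad.
                outputs = checkpoint(lambda x: model(x), low_res, use_reentrant=False)
            else:
                outputs = model(low_res)
            # Scale the loss so the accumulated gradient averages over the window.
            loss = criterion(outputs, high_res) / accumulation_steps
        scaler.scale(loss).backward()

        # Step the optimizer once per accumulation window.
        if i % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    # Flush gradients left over from an incomplete final window.
    if i % accumulation_steps != 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = TinyUpscaler().to(device)
    # Fake data: 6 batches of two 16x16 low-res / 32x32 high-res image pairs.
    batches = [(torch.randn(2, 3, 16, 16), torch.randn(2, 3, 32, 32)) for _ in range(6)]
    train_steps(model, batches, accumulation_steps=4, use_checkpoint=True, device=device)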