diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index 0141a6d..92fa954 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -12,7 +12,7 @@ import numpy as np from torch import nn from torch.utils.data import random_split, DataLoader from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive -from torch.cuda.amp import autocast, GradScaler +from torch.amp import autocast, GradScaler from tqdm import tqdm from torch.utils.checkpoint import checkpoint @@ -87,7 +87,6 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac ) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - # Fix: Pass the current device index (an integer) rather than a torch.device without index. if device.type == 'cuda': current_device = torch.cuda.current_device() torch.cuda.set_per_process_memory_fraction(0.95, device=current_device) @@ -109,7 +108,6 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac high_res = batch['high_res'].to(device) - with autocast(): + with autocast(device_type='cuda'): if use_checkpoint: - # Use checkpointing to save intermediate activations if needed. outputs = checkpoint(lambda x: model(x), low_res) else: outputs = model(low_res) @@ -120,7 +118,6 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac scaler.step(optimizer) scaler.update() optimizer.zero_grad() - # Handle leftover gradients if (i % accumulation_steps) != 0: scaler.step(optimizer) scaler.update() @@ -149,14 +146,24 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac return model def main(): - BATCH_SIZE = 2 # Use a batch size of 2. - ACCUMULATION_STEPS = 8 # Accumulate gradients to simulate a larger batch. - USE_CHECKPOINT = True # Set to True to enable gradient checkpointing if needed. 
- + BATCH_SIZE = 2 + ACCUMULATION_STEPS = 8 + USE_CHECKPOINT = True + model = AIIABase.load("/root/vision/AIIA/AIIA-base-512") + + # Upsample output if a 'chunked_' attribute exists, ensuring spatial dimensions match the high resolution images. if hasattr(model, 'chunked_'): model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear')) + # Append a final convolutional layer using values from model.config. + # This maps the hidden feature maps (model.config.hidden_size channels) to model.config.num_channels output channels; NOTE(review): add_module only registers the layer, so confirm the model's forward() actually applies it. + final_conv = nn.Conv2d(model.config.hidden_size, model.config.num_channels, kernel_size=3, padding=1) + model.add_module('final_layer', final_conv) + + print("Modified model architecture:") + print(model.config) + finetune_model( model=model, datasets=[