develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed file with 15 additions and 8 deletions
Showing only changes of commit ca29ee748b


@@ -12,7 +12,7 @@ import numpy as np
 from torch import nn
 from torch.utils.data import random_split, DataLoader
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
-from torch.cuda.amp import autocast, GradScaler
+from torch.amp import autocast, GradScaler
 from tqdm import tqdm
 from torch.utils.checkpoint import checkpoint
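
Side note on the import change above: torch.cuda.amp.autocast and torch.cuda.amp.GradScaler are deprecated in favour of the torch.amp equivalents, but the torch.amp versions take the device type explicitly. A minimal, self-contained sketch of the new call pattern (the tiny model and tensors are illustrative only, not part of this repository):

import torch
from torch import nn
from torch.amp import autocast, GradScaler

# Illustrative setup: a tiny model and one AMP-scaled optimizer step.
model = nn.Linear(8, 3).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
scaler = GradScaler('cuda')              # torch.amp.GradScaler takes the device type

x = torch.randn(4, 8, device='cuda')
target = torch.randn(4, 3, device='cuda')

with autocast('cuda'):                   # torch.amp.autocast expects device_type first
    loss = criterion(model(x), target)

scaler.scale(loss).backward()            # scale the loss before backward
scaler.step(optimizer)                   # unscale gradients and step
scaler.update()                          # adjust the scale factor for the next iteration
optimizer.zero_grad()

With this import path, the bare autocast() call further down in the file would likely need to become autocast('cuda') as well, since torch.amp.autocast requires the device type argument.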
@@ -87,7 +87,6 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
     )
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    # Fix: Pass the current device index (an integer) rather than a torch.device without index.
     if device.type == 'cuda':
         current_device = torch.cuda.current_device()
         torch.cuda.set_per_process_memory_fraction(0.95, device=current_device)
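
For context on the block above, set_per_process_memory_fraction applies per device and accepts the integer index that torch.cuda.current_device() returns; allocations beyond the cap raise an out-of-memory error instead of spilling. A minimal sketch of the same guard in isolation:

import torch

# Illustrative only: cap this process at ~95% of the active GPU's memory.
if torch.cuda.is_available():
    current_device = torch.cuda.current_device()  # integer index, e.g. 0
    torch.cuda.set_per_process_memory_fraction(0.95, device=current_device)
    print(f"Capped memory on cuda:{current_device}")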
@@ -109,7 +108,6 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
             high_res = batch['high_res'].to(device)
             with autocast():
                 if use_checkpoint:
-                    # Use checkpointing to save intermediate activations if needed.
                     outputs = checkpoint(lambda x: model(x), low_res)
                 else:
                     outputs = model(low_res)
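
For reference, the checkpoint call above trades memory for compute: activations inside the wrapped function are discarded after the forward pass and recomputed during backward. A standalone sketch of the same pattern, with use_reentrant spelled out because recent PyTorch versions warn when it is omitted (the small Sequential model is illustrative only):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, kernel_size=3, padding=1),
)
low_res = torch.randn(1, 3, 64, 64, requires_grad=True)

# Intermediate activations of the wrapped segment are not stored;
# they are recomputed when backward() runs.
outputs = checkpoint(lambda x: model(x), low_res, use_reentrant=False)
outputs.sum().backward()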
@@ -120,7 +118,6 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
-        # Handle leftover gradients
        if (i % accumulation_steps) != 0:
            scaler.step(optimizer)
            scaler.update()
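
The trailing `if (i % accumulation_steps) != 0` check exists because the number of batches is rarely an exact multiple of accumulation_steps, so the gradients accumulated since the last optimizer step still have to be applied once the loop ends. A standalone sketch of the whole pattern with 1-based batch indexing (train_one_epoch and its arguments are illustrative names, not this module's API):

import torch
from torch.amp import autocast, GradScaler

def train_one_epoch(model, loader, optimizer, criterion,
                    accumulation_steps=8, device='cuda'):
    scaler = GradScaler(device)
    optimizer.zero_grad()
    i = 0
    for i, (inputs, targets) in enumerate(loader, start=1):
        inputs, targets = inputs.to(device), targets.to(device)
        with autocast(device):
            # Divide so the accumulated gradient matches a large-batch average.
            loss = criterion(model(inputs), targets) / accumulation_steps
        scaler.scale(loss).backward()
        if i % accumulation_steps == 0:        # step once per full accumulation group
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    if i % accumulation_steps != 0:            # flush gradients from a partial last group
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()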
@@ -149,14 +146,24 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
     return model
 def main():
-    BATCH_SIZE = 2  # Use a batch size of 2.
-    ACCUMULATION_STEPS = 8  # Accumulate gradients to simulate a larger batch.
-    USE_CHECKPOINT = True  # Set to True to enable gradient checkpointing if needed.
+    BATCH_SIZE = 2
+    ACCUMULATION_STEPS = 8
+    USE_CHECKPOINT = True
     model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
-    # Upsample output if a 'chunked_' attribute exists, ensuring spatial dimensions match the high resolution images.
     if hasattr(model, 'chunked_'):
         model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
+    # Append a final convolutional layer using values from model.config.
+    # This converts the hidden feature maps (512 channels) to the 3 required output channels.
+    final_conv = nn.Conv2d(model.config.hidden_size, model.config.num_channels, kernel_size=3, padding=1)
+    model.add_module('final_layer', final_conv)
+    print("Modified model architecture:")
+    print(model.config)
     finetune_model(
         model=model,
         datasets=[