finetune_class #1
|
@ -12,7 +12,7 @@ import numpy as np
|
|||
from torch import nn
|
||||
from torch.utils.data import random_split, DataLoader
|
||||
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from torch.amp import autocast, GradScaler
|
||||
from tqdm import tqdm
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
|
@ -87,7 +87,6 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
|
|||
)
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
# Fix: Pass the current device index (an integer) rather than a torch.device without index.
|
||||
if device.type == 'cuda':
|
||||
current_device = torch.cuda.current_device()
|
||||
torch.cuda.set_per_process_memory_fraction(0.95, device=current_device)
|
||||
|
@ -109,7 +108,6 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
|
|||
high_res = batch['high_res'].to(device)
|
||||
with autocast():
|
||||
if use_checkpoint:
|
||||
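# Use gradient checkpointing: intermediate activations are recomputed during the backward pass instead of being stored, trading extra compute for lower memory use.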
|
||||
outputs = checkpoint(lambda x: model(x), low_res)
|
||||
else:
|
||||
outputs = model(low_res)
|
||||
|
@ -120,7 +118,6 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
|
|||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
# Handle leftover gradients
|
||||
if (i % accumulation_steps) != 0:
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
@ -149,14 +146,24 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
|
|||
return model
|
||||
|
||||
def main():
|
||||
BATCH_SIZE = 2 # Use a batch size of 2.
|
||||
ACCUMULATION_STEPS = 8 # Accumulate gradients to simulate a larger batch.
|
||||
USE_CHECKPOINT = True # Set to True to enable gradient checkpointing if needed.
|
||||
BATCH_SIZE = 2
|
||||
ACCUMULATION_STEPS = 8
|
||||
USE_CHECKPOINT = True
|
||||
|
||||
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
|
||||
|
||||
# If the model has a 'chunked_' attribute, register a 2x bilinear upsampling module so the output's spatial dimensions can match the high-resolution images.
|
||||
if hasattr(model, 'chunked_'):
|
||||
model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
|
||||
|
||||
# Append a final convolutional layer using values from model.config.
|
||||
# This converts the hidden feature maps (model.config.hidden_size channels, e.g. 512) to the required output channels (model.config.num_channels, e.g. 3).
|
||||
final_conv = nn.Conv2d(model.config.hidden_size, model.config.num_channels, kernel_size=3, padding=1)
|
||||
model.add_module('final_layer', final_conv)
|
||||
|
||||
print("Modified model architecture:")
|
||||
print(model.config)
|
||||
|
||||
finetune_model(
|
||||
model=model,
|
||||
datasets=[
|
||||
|
|
Loading…
Reference in New Issue