finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed file with 14 additions and 24 deletions
Showing only changes of commit 9645e1da23

@@ -14,6 +14,7 @@ from torch.utils.data import random_split, DataLoader
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
class aiuNNDataset(torch.utils.data.Dataset):
    def __init__(self, parquet_path):
@@ -86,48 +87,40 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Fix: Pass the current device index (an integer) rather than a torch.device without index.
    if device.type == 'cuda':
        torch.cuda.set_per_process_memory_fraction(0.95, device=device)
        model = model.to(device)
    current_device = torch.cuda.current_device()
    torch.cuda.set_per_process_memory_fraction(0.95, device=current_device)
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
    scaler = GradScaler()
    best_val_loss = float('inf')
    # Import checkpoint if gradient checkpointing is desired
    from torch.utils.checkpoint import checkpoint
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()
        # Gradient accumulation over several steps (effective batch size = accumulation_steps)
        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"), start=1):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            low_res = batch['low_res'].to(device)
            high_res = batch['high_res'].to(device)
            with autocast():
                if use_checkpoint:
                    # Wrap the forward pass with checkpointing to save memory.
                    # Use checkpointing to save intermediate activations if needed.
                    outputs = checkpoint(lambda x: model(x), low_res)
                else:
                    outputs = model(low_res)
                # Divide loss to average over accumulation steps.
                loss = criterion(outputs, high_res) / accumulation_steps
            scaler.scale(loss).backward()
            train_loss += loss.item() * accumulation_steps  # recover actual loss value
            # Update the optimizer every accumulation_steps iterations.
            train_loss += loss.item() * accumulation_steps
            if i % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        # In case remaining gradients are present from an incomplete accumulation round.
        # Handle leftover gradients
        if (i % accumulation_steps) != 0:
            scaler.step(optimizer)
            scaler.update()
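
Note: the hunk above pairs torch.cuda.amp (autocast + GradScaler) with gradient accumulation. As a minimal illustrative sketch, not the committed code, and assuming model, optimizer, criterion, train_loader, device and accumulation_steps are defined as in the file above, the step pattern is:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
optimizer.zero_grad()
for i, batch in enumerate(train_loader, start=1):
    low_res = batch['low_res'].to(device)
    high_res = batch['high_res'].to(device)
    with autocast():                            # mixed-precision forward pass
        loss = criterion(model(low_res), high_res) / accumulation_steps
    scaler.scale(loss).backward()               # gradients accumulate across iterations
    if i % accumulation_steps == 0:             # one optimizer step per accumulation window
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
if i % accumulation_steps != 0:                 # flush any leftover gradients
    scaler.step(optimizer)
    scaler.update()
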
@@ -136,7 +129,6 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
        # Validation loop (without accumulation, using standard precision)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
@@ -151,17 +143,15 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save("best_model")
    return model
def main():
    BATCH_SIZE = 1  # Use a batch size of 1.
    ACCUMULATION_STEPS = 8  # Accumulate gradients over 8 iterations for an effective batch size of 8.
    USE_CHECKPOINT = False  # Set to True to enable gradient checkpointing instead.
    BATCH_SIZE = 2  # Use a batch size of 2.
    ACCUMULATION_STEPS = 8  # Accumulate gradients to simulate a larger batch.
    USE_CHECKPOINT = True  # Set to True to enable gradient checkpointing if needed.
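    # Note: effective batch size per optimizer step = BATCH_SIZE * ACCUMULATION_STEPS (2 * 8 = 16 here; 1 * 8 = 8 before this commit).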
    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
    if hasattr(model, 'chunked_'):