fixed checkpoint

Falko Victor Habel 2025-02-14 22:14:09 +01:00
parent 5a6680178a
commit 9645e1da23
1 changed file with 14 additions and 24 deletions

@@ -14,6 +14,7 @@ from torch.utils.data import random_split, DataLoader
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
 from torch.cuda.amp import autocast, GradScaler
 from tqdm import tqdm
+from torch.utils.checkpoint import checkpoint
 
 class aiuNNDataset(torch.utils.data.Dataset):
     def __init__(self, parquet_path):
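
The import moved to module scope in this hunk is PyTorch's activation-checkpointing helper. As a minimal sketch of what it does (the tiny network and tensor below are placeholders, not code from this repository), checkpoint() discards intermediate activations during the forward pass and recomputes them during backward, trading extra compute for lower peak memory:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Placeholder network standing in for the AIIA model used in the script.
net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, 3, padding=1),
)

x = torch.randn(1, 3, 64, 64, requires_grad=True)

# Activations inside `net` are not stored; they are recomputed during backward.
# use_reentrant=False selects the non-reentrant variant available in recent PyTorch releases.
y = checkpoint(net, x, use_reentrant=False)
y.mean().backward()
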
@@ -86,57 +87,48 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
     )
 
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # Fix: Pass the current device index (an integer) rather than a torch.device without index.
     if device.type == 'cuda':
-        torch.cuda.set_per_process_memory_fraction(0.95, device=device)
+        current_device = torch.cuda.current_device()
+        torch.cuda.set_per_process_memory_fraction(0.95, device=current_device)
     model = model.to(device)
 
     criterion = nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
     scaler = GradScaler()
     best_val_loss = float('inf')
 
-    # Import checkpoint if gradient checkpointing is desired
-    from torch.utils.checkpoint import checkpoint
-
     for epoch in range(epochs):
         model.train()
         train_loss = 0.0
         optimizer.zero_grad()
 
-        # Gradient accumulation over several steps (effective batch size = accumulation_steps)
         for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"), start=1):
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
 
             low_res = batch['low_res'].to(device)
             high_res = batch['high_res'].to(device)
 
             with autocast():
                 if use_checkpoint:
-                    # Wrap the forward pass with checkpointing to save memory.
+                    # Use checkpointing to save intermediate activations if needed.
                     outputs = checkpoint(lambda x: model(x), low_res)
                 else:
                     outputs = model(low_res)
-                # Divide loss to average over accumulation steps.
                 loss = criterion(outputs, high_res) / accumulation_steps
 
             scaler.scale(loss).backward()
-            train_loss += loss.item() * accumulation_steps  # recover actual loss value
+            train_loss += loss.item() * accumulation_steps
 
-            # Update the optimizer every accumulation_steps iterations.
             if i % accumulation_steps == 0:
                 scaler.step(optimizer)
                 scaler.update()
                 optimizer.zero_grad()
 
-        # Handle leftover gradients
-        # In case remaining gradients are present from an incomplete accumulation round.
         if (i % accumulation_steps) != 0:
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
 
         avg_train_loss = train_loss / len(train_loader)
         print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
 
-        # Validation loop (without accumulation, using standard precision)
         model.eval()
         val_loss = 0.0
         with torch.no_grad():
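
The hunk above combines three things: a GPU memory cap set via the current CUDA device index, mixed-precision training with autocast and GradScaler, and gradient accumulation with a final flush for an incomplete window. A condensed, standalone sketch of that update pattern follows, with a placeholder convolutional model and random tensors in place of the aiuNN dataset (it assumes a CUDA device is available):

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

device = torch.device('cuda')
# As in the commit, the memory cap takes the current device index (an int).
torch.cuda.set_per_process_memory_fraction(0.95, device=torch.cuda.current_device())

model = nn.Conv2d(3, 3, 3, padding=1).to(device)  # placeholder for the AIIA model
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = GradScaler()
accumulation_steps = 8

# Fake loader: 20 mini-batches of (low_res, high_res) pairs.
loader = [(torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)) for _ in range(20)]

optimizer.zero_grad()
for i, (low_res, high_res) in enumerate(loader, start=1):
    low_res, high_res = low_res.to(device), high_res.to(device)
    with autocast():
        # Divide so the accumulated gradient equals the average over the window.
        loss = criterion(model(low_res), high_res) / accumulation_steps
    scaler.scale(loss).backward()
    if i % accumulation_steps == 0:
        scaler.step(optimizer)   # unscales gradients and skips the step on inf/nan
        scaler.update()
        optimizer.zero_grad()

# Flush gradients left over from a final, incomplete accumulation window.
if i % accumulation_steps != 0:
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
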
@@ -151,22 +143,20 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
                 val_loss += loss.item()
 
         avg_val_loss = val_loss / len(val_loader)
         print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
 
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
             model.save("best_model")
 
     return model
 
 def main():
-    BATCH_SIZE = 1  # Use a batch size of 1.
-    ACCUMULATION_STEPS = 8  # Accumulate gradients over 8 iterations for an effective batch size of 8.
-    USE_CHECKPOINT = False  # Set to True to enable gradient checkpointing instead.
+    BATCH_SIZE = 2  # Use a batch size of 2.
+    ACCUMULATION_STEPS = 8  # Accumulate gradients to simulate a larger batch.
+    USE_CHECKPOINT = True  # Set to True to enable gradient checkpointing if needed.
 
     model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
 
     if hasattr(model, 'chunked_'):
         model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
 
     finetune_model(
         model=model,
         datasets=[
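
The configuration change above doubles the per-step batch size while keeping eight accumulation steps. A quick check of the resulting effective batch size (constants copied from the diff; the variable name below is only illustrative):

BATCH_SIZE = 2
ACCUMULATION_STEPS = 8

# One optimizer update is taken every ACCUMULATION_STEPS mini-batches,
# so each update effectively averages gradients over their product.
effective_batch_size = BATCH_SIZE * ACCUMULATION_STEPS
print(effective_batch_size)  # 16 (previously 1 * 8 = 8)
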
@@ -178,6 +168,6 @@ def main():
         accumulation_steps=ACCUMULATION_STEPS,
         use_checkpoint=USE_CHECKPOINT
     )
 
 if __name__ == '__main__':
     main()