try checkpoints

Falko Victor Habel 2025-02-14 22:10:07 +01:00
parent e7e7e96001
commit 5a6680178a
1 changed file with 43 additions and 24 deletions


@@ -12,13 +12,12 @@ import numpy as np
from torch import nn
from torch.utils.data import random_split, DataLoader
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
from torch.amp import autocast, GradScaler
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
class aiuNNDataset(torch.utils.data.Dataset):
def __init__(self, parquet_path):
self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2000)
self.augmentation = Compose([
RandomBrightnessContrast(p=0.5),
HorizontalFlip(p=0.5),
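
The only substantive change in this hunk is the autocast/GradScaler import path. Note that torch.cuda.amp.autocast() can be entered without arguments, while the device-generic torch.amp.autocast requires a device_type, so the bare "with autocast():" later in this file only lines up with the torch.cuda.amp spelling. A minimal sketch of the pattern, assuming a CUDA device is available (the linear model and shapes are illustrative, not from this repository):

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler  # CUDA-specific spelling
# Device-generic alternative:
#   from torch.amp import autocast, GradScaler
#   with autocast('cuda'): ...

model = nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()

x = torch.randn(4, 8, device='cuda')
with autocast():                      # forward pass runs in mixed precision
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()         # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)                # unscales first; skips the step if infs/NaNs are found
scaler.update()
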
@@ -28,37 +27,31 @@ class aiuNNDataset(torch.utils.data.Dataset):
Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
ToTensorV2()
])
def __len__(self):
return len(self.df)
def load_image(self, image_data):
try:
if isinstance(image_data, str):
image_data = base64.b64decode(image_data)
if not isinstance(image_data, bytes):
raise ValueError("Invalid image data format")
image_stream = io.BytesIO(image_data)
ImageFile.LOAD_TRUNCATED_IMAGES = True
image = Image.open(image_stream).convert('RGB')
image_array = np.array(image)
return image_array
except Exception as e:
raise RuntimeError(f"Error loading image: {str(e)}")
finally:
if 'image_stream' in locals():
image_stream.close()
def __getitem__(self, idx):
row = self.df.iloc[idx]
low_res_image = self.load_image(row['image_512'])
high_res_image = self.load_image(row['image_1024'])
augmented_low = self.augmentation(image=low_res_image)
augmented_high = self.augmentation(image=high_res_image)
return {
@@ -66,10 +59,10 @@ class aiuNNDataset(torch.utils.data.Dataset):
'high_res': augmented_high['image']
}
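
For reference, load_image above accepts either raw encoded image bytes or the same bytes as a base64 string. A self-contained sketch of that contract; the decode_cell helper and the synthetic image are illustrative, not part of the repository:

import base64
import io
import numpy as np
from PIL import Image

def decode_cell(image_data):
    # Mirrors load_image's decoding: optional base64 string -> bytes -> RGB array.
    if isinstance(image_data, str):
        image_data = base64.b64decode(image_data)
    return np.array(Image.open(io.BytesIO(image_data)).convert('RGB'))

img = Image.new('RGB', (512, 512), color=(120, 60, 30))
buf = io.BytesIO()
img.save(buf, format='PNG')
raw = buf.getvalue()

assert np.array_equal(decode_cell(raw), decode_cell(base64.b64encode(raw).decode()))
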
def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
# Load and concatenate datasets.
loaded_datasets = [aiuNNDataset(d) for d in datasets]
combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
train_size = int(0.8 * len(combined_dataset))
val_size = len(combined_dataset) - train_size
train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])
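
Since each parquet file is truncated to 2,000 rows by .head(2000) in the dataset constructor, the concatenation and 80/20 split above work out roughly as follows (a back-of-the-envelope check, assuming both files contain at least 2,000 rows):

total = 2 * 2000                  # two parquet files, each capped at 2000 rows
train_size = int(0.8 * total)     # 3200 training samples
val_size = total - train_size     # 800 validation samples
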
@@ -93,38 +86,57 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Limit VRAM usage to 95% of available memory (reducing risk of overflow)
if device.type == 'cuda':
torch.cuda.set_per_process_memory_fraction(0.95, device=device)
model = model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
scaler = GradScaler()
best_val_loss = float('inf')
# Import checkpoint if gradient checkpointing is desired
from torch.utils.checkpoint import checkpoint
for epoch in range(epochs):
model.train()
train_loss = 0.0
for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"):
optimizer.zero_grad()
# Gradient accumulation over several steps (effective batch size = accumulation_steps)
for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"), start=1):
if torch.cuda.is_available():
torch.cuda.empty_cache()
low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device)
optimizer.zero_grad()
with autocast():
outputs = model(low_res)
loss = criterion(outputs, high_res)
if use_checkpoint:
# Wrap the forward pass with checkpointing to save memory.
outputs = checkpoint(lambda x: model(x), low_res)
else:
outputs = model(low_res)
# Divide loss to average over accumulation steps.
loss = criterion(outputs, high_res) / accumulation_steps
scaler.scale(loss).backward()
train_loss += loss.item() * accumulation_steps # recover actual loss value
# Update the optimizer every accumulation_steps iterations.
if i % accumulation_steps == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
# In case remaining gradients are present from an incomplete accumulation round.
if (i % accumulation_steps) != 0:
scaler.step(optimizer)
scaler.update()
train_loss += loss.item()
optimizer.zero_grad()
avg_train_loss = train_loss / len(train_loader)
print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
# Validation loop (without accumulation, using standard precision)
model.eval()
val_loss = 0.0
with torch.no_grad():
@@ -139,15 +151,19 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
val_loss += loss.item()
avg_val_loss = val_loss / len(val_loader)
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
model.save("best_model")
return model
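
Gradient checkpointing, toggled by use_checkpoint above, trades compute for memory: activations inside the wrapped forward are dropped and recomputed during backward. A minimal standalone sketch with a toy block (not the repository's model; the use_reentrant flag assumes a reasonably recent PyTorch):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 3, 3, padding=1),
)
x = torch.randn(1, 3, 64, 64, requires_grad=True)  # input must require grad so backward can rebuild the graph
y = checkpoint(block, x, use_reentrant=False)       # activations inside block are recomputed on backward
y.mean().backward()
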
def main():
BATCH_SIZE = 2
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
BATCH_SIZE = 1 # Use a batch size of 1.
ACCUMULATION_STEPS = 8 # Accumulate gradients over 8 iterations for an effective batch size of 8.
USE_CHECKPOINT = False # Set to True to enable gradient checkpointing instead.
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
if hasattr(model, 'chunked_'):
model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
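
The final_upsample module appended here simply doubles the spatial resolution, which matches the 512 -> 1024 pairs the dataset provides. A quick shape check with an illustrative tensor (align_corners is spelled out only in this sketch):

import torch
from torch import nn

up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
x = torch.randn(1, 3, 512, 512)
print(up(x).shape)   # torch.Size([1, 3, 1024, 1024])
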
@@ -157,7 +173,10 @@ def main():
"/root/training_data/vision-dataset/image_upscaler.parquet",
"/root/training_data/vision-dataset/image_vec_upscaler.parquet"
],
batch_size=BATCH_SIZE
batch_size=BATCH_SIZE,
epochs=10,
accumulation_steps=ACCUMULATION_STEPS,
use_checkpoint=USE_CHECKPOINT
)
if __name__ == '__main__':