finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed file with 43 additions and 24 deletions
Showing only changes of commit 5a6680178a


@@ -12,13 +12,12 @@ import numpy as np
 from torch import nn
 from torch.utils.data import random_split, DataLoader
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
-from torch.amp import autocast, GradScaler
+from torch.cuda.amp import autocast, GradScaler
 from tqdm import tqdm
 
 class aiuNNDataset(torch.utils.data.Dataset):
     def __init__(self, parquet_path):
         self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2000)
         self.augmentation = Compose([
             RandomBrightnessContrast(p=0.5),
             HorizontalFlip(p=0.5),
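Note: the import switch above moves autocast and GradScaler to torch.cuda.amp, which the training loop further down relies on for mixed-precision training. A minimal sketch of that pattern, assuming a CUDA device and using a toy nn.Linear model that is not part of this PR:

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(16, 16).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

x = torch.randn(4, 16, device='cuda')
target = torch.randn(4, 16, device='cuda')

optimizer.zero_grad()
with autocast():                       # forward pass and loss in mixed precision
    loss = nn.functional.mse_loss(model(x), target)
scaler.scale(loss).backward()          # scale the loss so fp16 gradients do not underflow
scaler.step(optimizer)                 # unscales gradients, skips the step on inf/nan
scaler.update()                        # adjust the scale factor for the next iteration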
@@ -36,16 +35,12 @@ class aiuNNDataset(torch.utils.data.Dataset):
         try:
             if isinstance(image_data, str):
                 image_data = base64.b64decode(image_data)
             if not isinstance(image_data, bytes):
                 raise ValueError("Invalid image data format")
             image_stream = io.BytesIO(image_data)
             ImageFile.LOAD_TRUNCATED_IMAGES = True
             image = Image.open(image_stream).convert('RGB')
             image_array = np.array(image)
             return image_array
         except Exception as e:
             raise RuntimeError(f"Error loading image: {str(e)}")
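Note: the decode path above can be exercised without a parquet file. The snippet below round-trips a generated image through the same base64 -> bytes -> PIL -> numpy steps; it is a standalone sketch mirroring load_image, not code from this PR:

import base64, io
import numpy as np
from PIL import Image

# Build a small RGB image and encode it the way the dataset expects to receive it.
buffer = io.BytesIO()
Image.new('RGB', (32, 32), color=(128, 64, 32)).save(buffer, format='PNG')
encoded = base64.b64encode(buffer.getvalue()).decode('ascii')

# Mirror of load_image: str -> bytes -> PIL image -> RGB numpy array.
image_data = base64.b64decode(encoded) if isinstance(encoded, str) else encoded
image = Image.open(io.BytesIO(image_data)).convert('RGB')
image_array = np.array(image)
assert image_array.shape == (32, 32, 3)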
@@ -55,10 +50,8 @@ class aiuNNDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         row = self.df.iloc[idx]
         low_res_image = self.load_image(row['image_512'])
         high_res_image = self.load_image(row['image_1024'])
         augmented_low = self.augmentation(image=low_res_image)
         augmented_high = self.augmentation(image=high_res_image)
         return {
@@ -66,10 +59,10 @@ class aiuNNDataset(torch.utils.data.Dataset):
             'high_res': augmented_high['image']
         }
 
-def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
+def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
+    # Load and concatenate datasets.
     loaded_datasets = [aiuNNDataset(d) for d in datasets]
     combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
     train_size = int(0.8 * len(combined_dataset))
     val_size = len(combined_dataset) - train_size
     train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])
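Note: the dataset plumbing above (concatenate, 80/20 split, then wrap in loaders) follows the standard torch.utils.data pattern. A self-contained sketch with a toy TensorDataset standing in for aiuNNDataset; the DataLoader arguments shown are assumptions, not necessarily the ones used further down in this file:

import torch
from torch.utils.data import ConcatDataset, TensorDataset, random_split, DataLoader

parts = [TensorDataset(torch.randn(100, 3, 64, 64), torch.randn(100, 3, 128, 128)) for _ in range(2)]
combined = ConcatDataset(parts)

train_size = int(0.8 * len(combined))
val_size = len(combined) - train_size
train_set, val_set = random_split(combined, [train_size, val_size])

train_loader = DataLoader(train_set, batch_size=1, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False, pin_memory=True)
print(len(train_set), len(val_set))  # 160 40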
@@ -93,38 +86,57 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
     )
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # Limit VRAM usage to 95% of available memory (reducing risk of overflow)
     if device.type == 'cuda':
         torch.cuda.set_per_process_memory_fraction(0.95, device=device)
     model = model.to(device)
     criterion = nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
     scaler = GradScaler()
     best_val_loss = float('inf')
+    # Import checkpoint if gradient checkpointing is desired
+    from torch.utils.checkpoint import checkpoint
     for epoch in range(epochs):
         model.train()
         train_loss = 0.0
-        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"):
+        optimizer.zero_grad()
+        # Gradient accumulation over several steps (effective batch size = accumulation_steps)
+        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"), start=1):
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
             low_res = batch['low_res'].to(device)
             high_res = batch['high_res'].to(device)
-            optimizer.zero_grad()
             with autocast():
-                outputs = model(low_res)
-                loss = criterion(outputs, high_res)
+                if use_checkpoint:
+                    # Wrap the forward pass with checkpointing to save memory.
+                    outputs = checkpoint(lambda x: model(x), low_res)
+                else:
+                    outputs = model(low_res)
+                # Divide loss to average over accumulation steps.
+                loss = criterion(outputs, high_res) / accumulation_steps
             scaler.scale(loss).backward()
-            scaler.step(optimizer)
-            scaler.update()
-            train_loss += loss.item()
+            train_loss += loss.item() * accumulation_steps  # recover actual loss value
+            # Update the optimizer every accumulation_steps iterations.
+            if i % accumulation_steps == 0:
+                scaler.step(optimizer)
+                scaler.update()
+                optimizer.zero_grad()
+        # In case remaining gradients are present from an incomplete accumulation round.
+        if (i % accumulation_steps) != 0:
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
         avg_train_loss = train_loss / len(train_loader)
         print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
+        # Validation loop (without accumulation, using standard precision)
         model.eval()
         val_loss = 0.0
         with torch.no_grad():
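Note: the accumulation logic added above steps the optimizer every accumulation_steps micro-batches and flushes any trailing partial group after the loop, so the effective batch size is batch_size x accumulation_steps (here 1 x 8 = 8). A stripped-down sketch of that cadence on CPU, without AMP and with a toy model, just to make the update schedule explicit:

import torch
from torch import nn

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
accumulation_steps = 8
num_batches = 20  # deliberately not a multiple of accumulation_steps

optimizer.zero_grad()
for i in range(1, num_batches + 1):
    x, target = torch.randn(1, 8), torch.randn(1, 8)
    loss = nn.functional.mse_loss(model(x), target) / accumulation_steps
    loss.backward()                      # gradients accumulate across micro-batches
    if i % accumulation_steps == 0:      # optimizer steps at i = 8 and i = 16
        optimizer.step()
        optimizer.zero_grad()
if i % accumulation_steps != 0:          # flush the leftover 4 micro-batches
    optimizer.step()
    optimizer.zero_grad()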
@@ -139,15 +151,19 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
                 val_loss += loss.item()
         avg_val_loss = val_loss / len(val_loader)
         print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
             model.save("best_model")
     return model
 
 def main():
-    BATCH_SIZE = 2
-    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
+    BATCH_SIZE = 1  # Use a batch size of 1.
+    ACCUMULATION_STEPS = 8  # Accumulate gradients over 8 iterations for an effective batch size of 8.
+    USE_CHECKPOINT = False  # Set to True to enable gradient checkpointing instead.
+    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
     if hasattr(model, 'chunked_'):
         model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
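Note: USE_CHECKPOINT toggles the torch.utils.checkpoint path in the training loop, trading extra forward compute for lower activation memory. A minimal sketch with a stand-in module; the use_reentrant=False argument is an assumption about the installed PyTorch version and is not used in the PR itself, which calls checkpoint with a lambda and the library defaults:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, kernel_size=3, padding=1),
)
x = torch.randn(1, 3, 64, 64, requires_grad=True)

# Activations inside `block` are not stored; they are recomputed during backward.
out = checkpoint(block, x, use_reentrant=False)
out.mean().backward()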
@@ -157,7 +173,10 @@ def main():
             "/root/training_data/vision-dataset/image_upscaler.parquet",
             "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
         ],
-        batch_size=BATCH_SIZE
+        batch_size=BATCH_SIZE,
+        epochs=10,
+        accumulation_steps=ACCUMULATION_STEPS,
+        use_checkpoint=USE_CHECKPOINT
     )
 
 if __name__ == '__main__':