From 86664b10a6f870950b8489d347f0137f6ebb83d1 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sun, 23 Feb 2025 22:26:48 +0100
Subject: [PATCH] improved vram usage?

---
 src/aiunn/finetune.py | 42 ++++++++++++++++++++++++++----------------
 1 file changed, 26 insertions(+), 16 deletions(-)

diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 9d83ba9..f0bb3e8 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -10,6 +10,8 @@ from torch.amp import autocast, GradScaler
 from torch.utils.data import Dataset, DataLoader
 from torchvision import transforms
 from tqdm import tqdm
+from torch.utils.checkpoint import checkpoint
+import gc
 from aiia import AIIABase
 from upsampler import Upsampler
 
@@ -24,12 +26,10 @@ class EarlyStopping:
         self.early_stop = False
 
     def __call__(self, epoch_loss):
-        # If current loss is lower than the best loss minus min_delta, update best loss and reset counter.
         if epoch_loss < self.best_loss - self.min_delta:
             self.best_loss = epoch_loss
             self.counter = 0
         else:
-            # No significant improvement: increment counter.
             self.counter += 1
             if self.counter >= self.patience:
                 self.early_stop = True
@@ -40,11 +40,9 @@ class UpscaleDataset(Dataset):
     def __init__(self, parquet_files: list, transform=None):
         combined_df = pd.DataFrame()
         for parquet_file in parquet_files:
-            # Load data with head() to limit rows for memory efficiency.
             df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(500)
             combined_df = pd.concat([combined_df, df], ignore_index=True)
 
-        # Validate that each row has proper image formats.
         self.df = combined_df.apply(self._validate_row, axis=1)
         self.transform = transform
         self.failed_indices = set()
@@ -69,7 +67,6 @@ class UpscaleDataset(Dataset):
         return len(self.df)
 
     def __getitem__(self, idx):
-        # Skip indices that have previously failed.
         if idx in self.failed_indices:
             return self[(idx + 1) % len(self)]
         try:
@@ -79,7 +76,6 @@ class UpscaleDataset(Dataset):
             ImageFile.LOAD_TRUNCATED_IMAGES = True
             low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
             high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
-            # Validate expected sizes
             if low_res.size != (512, 512) or high_res.size != (1024, 1024):
                 raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
             if self.transform:
@@ -91,7 +87,7 @@ class UpscaleDataset(Dataset):
             self.failed_indices.add(idx)
             return self[(idx + 1) % len(self)]
 
-# Define any transformations you require (e.g., converting PIL images to tensors)
+# Define any transformations you require.
 transform = transforms.Compose([
     transforms.ToTensor(),
 ])
@@ -100,15 +96,20 @@ transform = transforms.Compose([
 pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
 base_model = AIIABase.load(pretrained_model_path)
 model = Upsampler(base_model)
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = model.to(device)
+# Move model to device using channels_last memory format.
+model = model.to(device, memory_format=torch.channels_last)
+
+# Optional: flag to enable gradient checkpointing.
+use_checkpointing = True
 
 # Create the dataset and dataloader.
 dataset = UpscaleDataset([
     "/root/training_data/vision-dataset/image_upscaler.parquet",
     "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
 ], transform=transform)
-data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
+data_loader = DataLoader(dataset, batch_size=1, shuffle=True)  # Consider adjusting num_workers if needed.
 
 # Define loss function and optimizer.
 criterion = nn.MSELoss()
@@ -124,7 +125,6 @@ with open(csv_file, mode='a', newline='') as file:
     if file.tell() == 0:
         writer.writerow(['Epoch', 'Train Loss'])
 
-# Initialize automatic mixed precision scaler and EarlyStopping.
 scaler = GradScaler()
 early_stopping = EarlyStopping(patience=3, min_delta=0.001)
 
@@ -132,16 +132,20 @@ early_stopping = EarlyStopping(patience=3, min_delta=0.001)
 for epoch in range(num_epochs):
     epoch_loss = 0.0
     progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
-    print(f"Epoch: {epoch}")
+    print(f"Epoch: {epoch + 1}")
 
     for low_res, high_res in progress_bar:
-        low_res = low_res.to(device, non_blocking=True)
+        # Move data to GPU with channels_last format where possible.
+        low_res = low_res.to(device, non_blocking=True).to(memory_format=torch.channels_last)
         high_res = high_res.to(device, non_blocking=True)
         optimizer.zero_grad()
 
-        # Use automatic mixed precision to speed up training on supported hardware.
         with autocast(device_type=device.type):
-            outputs = model(low_res)
+            if use_checkpointing:
+                # Wrap the forward pass with checkpointing to trade compute for memory.
+                outputs = checkpoint(lambda x: model(x), low_res)
+            else:
+                outputs = model(low_res)
             loss = criterion(outputs, high_res)
 
         scaler.scale(loss).backward()
@@ -150,6 +154,13 @@ for epoch in range(num_epochs):
 
         epoch_loss += loss.item()
         progress_bar.set_postfix({'loss': loss.item()})
+
+        # Optionally delete variables to free memory.
+        del low_res, high_res, outputs, loss
+
+    # Perform garbage collection and clear GPU cache after each epoch.
+    gc.collect()
+    torch.cuda.empty_cache()
 
     print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")
 
@@ -158,11 +169,10 @@ for epoch in range(num_epochs):
         writer = csv.writer(file)
         writer.writerow([epoch + 1, epoch_loss])
 
-    # Check early stopping criteria.
     if early_stopping(epoch_loss):
        print(f"Early stopping triggered at epoch {epoch + 1} with loss {epoch_loss}")
        break
 
-# Optionally, save the finetuned model using your library's save method.
+# Optionally save the fine-tuned model.
 finetuned_model_path = "aiuNN"
 model.save(finetuned_model_path)
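
Side note on the checkpointing pattern introduced above: torch.utils.checkpoint trades compute for memory by discarding intermediate activations during the forward pass and recomputing them during backward. Below is a minimal, self-contained sketch of the same combination (AMP + channels_last + gradient checkpointing), kept outside the patch for reference. TinyUpsampler, the tensor shapes, and the learning rate are hypothetical stand-ins for the real Upsampler and dataset, and the sketch assumes a PyTorch version that accepts use_reentrant=False.

# Standalone sketch, illustration only (not part of the patch).
import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint

class TinyUpsampler(nn.Module):
    """Toy 2x upsampler standing in for the real Upsampler."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

    def forward(self, x):
        return self.up(self.conv(x))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyUpsampler().to(device, memory_format=torch.channels_last)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()
scaler = GradScaler()

# Dummy batch: 64x64 input upscaled to 128x128 target.
low_res = torch.randn(1, 3, 64, 64, device=device).to(memory_format=torch.channels_last)
high_res = torch.randn(1, 3, 128, 128, device=device)

optimizer.zero_grad()
with autocast(device_type=device.type):
    # use_reentrant=False still computes parameter gradients even though
    # low_res itself does not require grad; the reentrant variant warns and
    # can leave parameter gradients unset in that situation.
    outputs = checkpoint(model, low_res, use_reentrant=False)
    loss = criterion(outputs, high_res)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

Passing the module itself to checkpoint (rather than a fresh lambda on every step, as the patch does) is an equivalent formulation; the use_reentrant=False path also sidesteps the no-grad-input caveat mentioned in the comment above.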