Attempt to reduce VRAM usage (channels_last, gradient checkpointing, per-epoch cache clearing)

This commit is contained in:
Falko Victor Habel 2025-02-23 22:26:48 +01:00
parent 7603ce8851
commit 86664b10a6
1 changed file with 26 additions and 16 deletions

View File

@ -10,6 +10,8 @@ from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc
from aiia import AIIABase
from upsampler import Upsampler
@ -24,12 +26,10 @@ class EarlyStopping:
self.early_stop = False
def __call__(self, epoch_loss):
# If current loss is lower than the best loss minus min_delta, update best loss and reset counter.
if epoch_loss < self.best_loss - self.min_delta:
self.best_loss = epoch_loss
self.counter = 0
else:
# No significant improvement: increment counter.
self.counter += 1
if self.counter >= self.patience:
self.early_stop = True
@ -40,11 +40,9 @@ class UpscaleDataset(Dataset):
def __init__(self, parquet_files: list, transform=None):
combined_df = pd.DataFrame()
for parquet_file in parquet_files:
# Load data with head() to limit rows for memory efficiency.
df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(500)
combined_df = pd.concat([combined_df, df], ignore_index=True)
# Validate that each row has proper image formats.
self.df = combined_df.apply(self._validate_row, axis=1)
self.transform = transform
self.failed_indices = set()
@ -69,7 +67,6 @@ class UpscaleDataset(Dataset):
return len(self.df)
def __getitem__(self, idx):
# Skip indices that have previously failed.
if idx in self.failed_indices:
return self[(idx + 1) % len(self)]
try:
@ -79,7 +76,6 @@ class UpscaleDataset(Dataset):
ImageFile.LOAD_TRUNCATED_IMAGES = True
low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
# Validate expected sizes
if low_res.size != (512, 512) or high_res.size != (1024, 1024):
raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
if self.transform:
@ -91,7 +87,7 @@ class UpscaleDataset(Dataset):
self.failed_indices.add(idx)
return self[(idx + 1) % len(self)]
# Define any transformations you require (e.g., converting PIL images to tensors)
# Define any transformations you require.
transform = transforms.Compose([
transforms.ToTensor(),
])
@ -100,15 +96,20 @@ transform = transforms.Compose([
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
base_model = AIIABase.load(pretrained_model_path)
model = Upsampler(base_model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Move model to device using channels_last memory format.
model = model.to(device, memory_format=torch.channels_last)
# Optional: flag to enable gradient checkpointing.
use_checkpointing = True
# Create the dataset and dataloader.
dataset = UpscaleDataset([
"/root/training_data/vision-dataset/image_upscaler.parquet",
"/root/training_data/vision-dataset/image_vec_upscaler.parquet"
], transform=transform)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True) # Consider adjusting num_workers if needed.
# Define loss function and optimizer.
criterion = nn.MSELoss()
@ -124,7 +125,6 @@ with open(csv_file, mode='a', newline='') as file:
if file.tell() == 0:
writer.writerow(['Epoch', 'Train Loss'])
# Initialize automatic mixed precision scaler and EarlyStopping.
scaler = GradScaler()
early_stopping = EarlyStopping(patience=3, min_delta=0.001)
@ -132,16 +132,20 @@ early_stopping = EarlyStopping(patience=3, min_delta=0.001)
for epoch in range(num_epochs):
epoch_loss = 0.0
progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
print(f"Epoch: {epoch}")
print(f"Epoch: {epoch + 1}")
for low_res, high_res in progress_bar:
low_res = low_res.to(device, non_blocking=True)
# Move the batch to the target device; only the low-res input is converted to channels_last.
low_res = low_res.to(device, non_blocking=True).to(memory_format=torch.channels_last)
high_res = high_res.to(device, non_blocking=True)
optimizer.zero_grad()
# Use automatic mixed precision to speed up training on supported hardware.
with autocast(device_type=device.type):
    if use_checkpointing:
        # Trade compute for memory: activations are recomputed during
        # backward instead of being stored.
        # The non-reentrant implementation is required here because the
        # only tensor input (low_res, straight from the data loader) does
        # not require grad; the reentrant variant would return an output
        # detached from the autograd graph and backward() would fail.
        outputs = checkpoint(model, low_res, use_reentrant=False)
    else:
        outputs = model(low_res)
    # Loss is computed inside the autocast region, against the high-res target.
    loss = criterion(outputs, high_res)
scaler.scale(loss).backward()
@ -150,6 +154,13 @@ for epoch in range(num_epochs):
epoch_loss += loss.item()
progress_bar.set_postfix({'loss': loss.item()})
# Drop references to the batch tensors so their memory can be released.
del low_res, high_res, outputs, loss
# Perform garbage collection and clear GPU cache after each epoch.
gc.collect()
torch.cuda.empty_cache()
print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")
@ -158,11 +169,10 @@ for epoch in range(num_epochs):
writer = csv.writer(file)
writer.writerow([epoch + 1, epoch_loss])
# Check early stopping criteria.
if early_stopping(epoch_loss):
print(f"Early stopping triggered at epoch {epoch + 1} with loss {epoch_loss}")
break
# Optionally, save the finetuned model using your library's save method.
# Optionally save the fine-tuned model.
finetuned_model_path = "aiuNN"
model.save(finetuned_model_path)