Attempt to reduce VRAM usage (channels_last memory format, gradient checkpointing, gc + CUDA cache clearing)
This commit is contained in:
parent
7603ce8851
commit
86664b10a6
|
@ -10,6 +10,8 @@ from torch.amp import autocast, GradScaler
|
|||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import gc
|
||||
|
||||
from aiia import AIIABase
|
||||
from upsampler import Upsampler
|
||||
|
@ -24,12 +26,10 @@ class EarlyStopping:
|
|||
self.early_stop = False
|
||||
|
||||
def __call__(self, epoch_loss):
|
||||
# If current loss is lower than the best loss minus min_delta, update best loss and reset counter.
|
||||
if epoch_loss < self.best_loss - self.min_delta:
|
||||
self.best_loss = epoch_loss
|
||||
self.counter = 0
|
||||
else:
|
||||
# No significant improvement: increment counter.
|
||||
self.counter += 1
|
||||
if self.counter >= self.patience:
|
||||
self.early_stop = True
|
||||
|
@ -40,11 +40,9 @@ class UpscaleDataset(Dataset):
|
|||
def __init__(self, parquet_files: list, transform=None):
|
||||
combined_df = pd.DataFrame()
|
||||
for parquet_file in parquet_files:
|
||||
# Load data with head() to limit rows for memory efficiency.
|
||||
df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(500)
|
||||
combined_df = pd.concat([combined_df, df], ignore_index=True)
|
||||
|
||||
# Validate that each row has proper image formats.
|
||||
self.df = combined_df.apply(self._validate_row, axis=1)
|
||||
self.transform = transform
|
||||
self.failed_indices = set()
|
||||
|
@ -69,7 +67,6 @@ class UpscaleDataset(Dataset):
|
|||
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Skip indices that have previously failed.
|
||||
if idx in self.failed_indices:
|
||||
return self[(idx + 1) % len(self)]
|
||||
try:
|
||||
|
@ -79,7 +76,6 @@ class UpscaleDataset(Dataset):
|
|||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
|
||||
high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
|
||||
# Validate expected sizes
|
||||
if low_res.size != (512, 512) or high_res.size != (1024, 1024):
|
||||
raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
|
||||
if self.transform:
|
||||
|
@ -91,7 +87,7 @@ class UpscaleDataset(Dataset):
|
|||
self.failed_indices.add(idx)
|
||||
return self[(idx + 1) % len(self)]
|
||||
|
||||
# Define any transformations you require (e.g., converting PIL images to tensors)
|
||||
# Define any transformations you require.
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
@ -100,15 +96,20 @@ transform = transforms.Compose([
|
|||
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
|
||||
base_model = AIIABase.load(pretrained_model_path)
|
||||
model = Upsampler(base_model)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
# Move model to device using channels_last memory format.
|
||||
model = model.to(device, memory_format=torch.channels_last)
|
||||
|
||||
# Optional: flag to enable gradient checkpointing.
|
||||
use_checkpointing = True
|
||||
|
||||
# Create the dataset and dataloader.
|
||||
dataset = UpscaleDataset([
|
||||
"/root/training_data/vision-dataset/image_upscaler.parquet",
|
||||
"/root/training_data/vision-dataset/image_vec_upscaler.parquet"
|
||||
], transform=transform)
|
||||
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
|
||||
data_loader = DataLoader(dataset, batch_size=1, shuffle=True) # Consider adjusting num_workers if needed.
|
||||
|
||||
# Define loss function and optimizer.
|
||||
criterion = nn.MSELoss()
|
||||
|
@ -124,7 +125,6 @@ with open(csv_file, mode='a', newline='') as file:
|
|||
if file.tell() == 0:
|
||||
writer.writerow(['Epoch', 'Train Loss'])
|
||||
|
||||
# Initialize automatic mixed precision scaler and EarlyStopping.
|
||||
scaler = GradScaler()
|
||||
early_stopping = EarlyStopping(patience=3, min_delta=0.001)
|
||||
|
||||
|
@ -132,16 +132,20 @@ early_stopping = EarlyStopping(patience=3, min_delta=0.001)
|
|||
for epoch in range(num_epochs):
|
||||
epoch_loss = 0.0
|
||||
progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
|
||||
print(f"Epoch: {epoch}")
|
||||
print(f"Epoch: {epoch + 1}")
|
||||
for low_res, high_res in progress_bar:
|
||||
low_res = low_res.to(device, non_blocking=True)
|
||||
# Move data to GPU with channels_last format where possible.
|
||||
low_res = low_res.to(device, non_blocking=True).to(memory_format=torch.channels_last)
|
||||
high_res = high_res.to(device, non_blocking=True)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Use automatic mixed precision to speed up training on supported hardware.
|
||||
with autocast(device_type=device.type):
|
||||
outputs = model(low_res)
|
||||
if use_checkpointing:
|
||||
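# Wrap the forward pass with checkpointing to trade compute for memory. NOTE(review): low_res does not require grad, so the reentrant checkpoint variant would produce no parameter gradients — confirm whether use_reentrant=False should be passed here.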
|
||||
outputs = checkpoint(lambda x: model(x), low_res)
|
||||
else:
|
||||
outputs = model(low_res)
|
||||
loss = criterion(outputs, high_res)
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
@ -150,6 +154,13 @@ for epoch in range(num_epochs):
|
|||
|
||||
epoch_loss += loss.item()
|
||||
progress_bar.set_postfix({'loss': loss.item()})
|
||||
|
||||
# Optionally delete variables to free memory.
|
||||
del low_res, high_res, outputs, loss
|
||||
|
||||
# Perform garbage collection and clear GPU cache after each epoch.
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")
|
||||
|
||||
|
@ -158,11 +169,10 @@ for epoch in range(num_epochs):
|
|||
writer = csv.writer(file)
|
||||
writer.writerow([epoch + 1, epoch_loss])
|
||||
|
||||
# Check early stopping criteria.
|
||||
if early_stopping(epoch_loss):
|
||||
print(f"Early stopping triggered at epoch {epoch + 1} with loss {epoch_loss}")
|
||||
break
|
||||
|
||||
# Optionally, save the finetuned model using your library's save method.
|
||||
# Optionally save the fine-tuned model.
|
||||
finetuned_model_path = "aiuNN"
|
||||
model.save(finetuned_model_path)
|
||||
|
|
Loading…
Reference in New Issue