import io

import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import Dataset, DataLoader

from aiia.model import AIIABase


# Step 1: Define Custom Dataset Class
class ImageDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        # Decode image_512 from bytes
        img_bytes = row['image_512']
        img_stream = io.BytesIO(img_bytes)
        low_res_image = Image.open(img_stream).convert('RGB')

        # Decode image_1024 from bytes
        high_res_bytes = row['image_1024']
        high_stream = io.BytesIO(high_res_bytes)
        high_res_image = Image.open(high_stream).convert('RGB')

        # Apply transformations if specified
        if self.transform:
            low_res_image = self.transform(low_res_image)
            high_res_image = self.transform(high_res_image)

        return {'low_res': low_res_image, 'high_res': high_res_image}
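
# Each item pairs one low-res input with its high-res target; given the column
# names, presumably 512x512 and 1024x1024 RGB images (3xHxW tensors once the
# transforms below are applied).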


# Step 2: Load and Preprocess Data
# Read the datasets (parquet files whose 'image_512'/'image_1024' columns hold encoded image bytes)
df1 = pd.read_parquet('/root/training_data/vision-dataset/image_upscaler.parquet')
df2 = pd.read_parquet('/root/training_data/vision-dataset/image_vec_upscaler.parquet')

# Combine the two datasets into one DataFrame
df = pd.concat([df1, df2], ignore_index=True)

# Split into training and validation sets
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

# Define preprocessing transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
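
# Note: Normalize with mean=std=0.5 maps [0, 1] inputs to [-1, 1], so model
# outputs will live in the same range. To view or save a prediction as an
# image later, undo the normalization first, e.g.:
#     transforms.ToPILImage()((output * 0.5 + 0.5).clamp(0, 1).cpu())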

train_dataset = ImageDataset(train_df, transform=transform)
val_dataset = ImageDataset(val_df, transform=transform)

# Create DataLoaders
batch_size = 2
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)


# Step 3: Load Pre-trained Model and Modify for Upscaling
model = AIIABase.load("AIIA-Base-512")

# Freeze original CNN layers to prevent catastrophic forgetting
for param in model.cnn.parameters():
    param.requires_grad = False

# Add upsample module
hidden_size = model.config.hidden_size  # Assuming this is defined in your model's config
model.upsample = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
    nn.Conv2d(hidden_size, 3, kernel_size=3, padding=1)
)
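
# Caveat: the head above doubles the spatial size of the CNN's feature maps and
# projects hidden_size channels down to 3. Whether the result is 1024x1024
# depends on AIIABase internals (not shown here); the one-batch check after the
# device setup in Step 5 verifies this before training starts.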

# Step 4: Define Loss Function and Optimizer
criterion = nn.MSELoss()

# The CNN backbone is frozen above, so optimize only the new upsample layers.
# (To fine-tune the whole model instead, pass model.parameters() at a lower
# learning rate, e.g. 0.0001.)
params_to_update = [param for name, param in model.named_parameters() if 'upsample' in name]
optimizer = torch.optim.Adam(params_to_update, lr=0.001)  # Adjust learning rate as needed

# Step 5: Training Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
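
# One-batch shape sanity check (a sketch; it assumes only the cnn/upsample
# attributes wired up above). Fails fast if the upsample head's output size
# does not match the high-res targets.
with torch.no_grad():
    sample = next(iter(train_loader))
    sample_out = model.upsample(model.cnn(sample['low_res'].to(device)))
    assert sample_out.shape == sample['high_res'].shape, (
        f"upsample output {tuple(sample_out.shape)} != "
        f"target {tuple(sample['high_res'].shape)}"
    )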

best_val_loss = float('inf')
num_epochs = 10  # Adjust as needed

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch in train_loader:
        low_res = batch['low_res'].to(device)
        high_res = batch['high_res'].to(device)

        # Forward pass
        features = model.cnn(low_res)
        outputs = model.upsample(features)

        loss = criterion(outputs, high_res)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    # Validation Step
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for batch in val_loader:
            low_res = batch['low_res'].to(device)
            high_res = batch['high_res'].to(device)

            features = model.cnn(low_res)
            outputs = model.upsample(features)

            loss = criterion(outputs, high_res)
            val_loss += loss.item()

    # Average over batches so the value is comparable to the training loss
    val_loss /= len(val_loader)
    print(f"Validation Loss: {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        model.save("AIIA-base-512-upscaler")
        print("Best model saved!")