updated dataloader to work with tuples

parent 3f6e6514a9
commit 7c4aef0978
@@ -106,10 +106,11 @@ class JPGImageLoader:
         print(f"Skipped {self.skipped_count} images due to errors.")
 
 class AIIADataLoader:
-    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, **dataloader_kwargs):
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
         self.batch_size = batch_size
         self.val_split = val_split
         self.seed = seed
+        self.pretraining = pretraining
         random.seed(seed)
 
         sample_value = dataset[column].iloc[0]
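The only API change in this hunk is the new `pretraining` keyword, defaulting to `False` so existing call sites keep their behavior. A minimal sketch of a call site (`df` is a hypothetical pandas DataFrame; the column name mirrors the training script further down):

```python
# Sketch: `df` is a hypothetical DataFrame whose "image_bytes" column holds images.
loader = AIIADataLoader(df, column="image_bytes", batch_size=32, pretraining=True)
```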
@@ -134,7 +135,12 @@ class AIIADataLoader:
         for idx in range(len(dataset)):
             item = self.loader.get_item(idx)
             if item is not None:
-                self.items.append(item)
+                if self.pretraining:
+                    img = item[0] if isinstance(item, tuple) else item
+                    self.items.append((img, 'denoise', img))
+                    self.items.append((img, 'rotate', 0))
+                else:
+                    self.items.append(item)
 
         if not self.items:
             raise ValueError("No valid items were loaded from the dataset")
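In pretraining mode each successfully loaded image is enqueued twice, once per self-supervised task, so the dataset length doubles. A standalone sketch of that expansion, where `raw_items` stands in for whatever `self.loader.get_item` yields:

```python
# Sketch of the expansion above; `raw_items` is a hypothetical list of loader
# outputs, each either a PIL image or an (image, label) tuple.
items = []
for item in raw_items:
    img = item[0] if isinstance(item, tuple) else item
    items.append((img, 'denoise', img))  # target: the clean image itself
    items.append((img, 'rotate', 0))     # placeholder label; the angle is drawn in __getitem__
assert len(items) == 2 * len(raw_items)
```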
@@ -163,12 +169,14 @@ class AIIADataLoader:
 
     def _create_subset(self, indices):
         subset_items = [self.items[i] for i in indices]
-        return AIIADataset(subset_items)
+        return AIIADataset(subset_items, pretraining=self.pretraining)
 
 class AIIADataset(torch.utils.data.Dataset):
-    def __init__(self, items):
+    def __init__(self, items, pretraining=False):
         self.items = items
+        self.pretraining = pretraining
         self.transform = transforms.Compose([
+            transforms.Resize((224, 224)),
             transforms.ToTensor()
         ])
 
@@ -177,29 +185,29 @@ class AIIADataset(torch.utils.data.Dataset):
 
     def __getitem__(self, idx):
         item = self.items[idx]
-        if isinstance(item, tuple) and len(item) == 2:
-            image, label = item
-            image = self.transform(image)
-            return (image, label)
-        elif isinstance(item, tuple) and len(item) == 3:
+        if self.pretraining:
             image, task, label = item
             image = self.transform(image)
 
             if task == 'denoise':
                 noise_std = 0.1
                 noisy_img = image + torch.randn_like(image) * noise_std
-                target = image
-                return (noisy_img, target, task)
+                target = image.clone()
+                return noisy_img, target, task
             elif task == 'rotate':
                 angles = [0, 90, 180, 270]
                 angle = random.choice(angles)
                 rotated_img = transforms.functional.rotate(image, angle)
-                target = torch.tensor(angle).long()
-                return (rotated_img, target, task)
-            else:
-                raise ValueError(f"Unknown task: {task}")
+                target = torch.tensor(angle / 90).long()
+                return rotated_img, target, task
         else:
-            if isinstance(item, Image.Image):
-                return self.transform(item)
+            if isinstance(item, tuple) and len(item) == 2:
+                image, label = item
+                image = self.transform(image)
+                return image, label
             else:
-                raise ValueError("Invalid item format.")
+                if isinstance(item, Image.Image):
+                    return self.transform(item)
+                else:
+                    return self.transform(item[0])
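Two details of the new `__getitem__` are worth spelling out: the denoise target is `image.clone()` so the loss is computed against a copy rather than the very tensor being corrupted, and the rotate label becomes a class index in 0..3 (the placeholder `0` stored at load time is discarded because the angle is redrawn here), which is the form `nn.CrossEntropyLoss` expects. A standalone sketch of both target constructions:

```python
import random
import torch
from torchvision import transforms

image = torch.rand(3, 224, 224)  # stand-in for a transformed image

# Denoise: input is the corrupted image, target a clean copy.
noisy = image + torch.randn_like(image) * 0.1
denoise_target = image.clone()

# Rotate: input is the rotated image, target the angle as a class index 0..3.
angle = random.choice([0, 90, 180, 270])
rotated = transforms.functional.rotate(image, angle)
rotate_target = torch.tensor(angle / 90).long()  # 0, 1, 2 or 3
```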
@@ -1,37 +1,26 @@
 import torch
 from torch import nn
-from torch.utils.data import Dataset, DataLoader
-from torchvision import transforms
-from PIL import Image
-import os
-import random
 import pandas as pd
 from aiia.model.config import AIIAConfig
 from aiia.model import AIIABase
 from aiia.data.DataLoader import AIIADataLoader
 
 def pretrain_model(data_path1, data_path2, num_epochs=3):
-    # Merge the two parquet files
     df1 = pd.read_parquet(data_path1).head(10000)
     df2 = pd.read_parquet(data_path2).head(10000)
     merged_df = pd.concat([df1, df2], ignore_index=True)
 
-    # Create a new AIIAConfig instance
     config = AIIAConfig(
         model_name="AIIA-Base-512x20k",
     )
 
-    # Initialize the base model
     model = AIIABase(config)
 
-    # Create dataset loader with merged data
-    aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32)
+    aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32, pretraining=True)
 
-    # Access the train and validation loaders
    train_loader = aiia_loader.train_loader
    val_loader = aiia_loader.val_loader
 
-    # Initialize loss functions and optimizer
    criterion_denoise = nn.MSELoss()
    criterion_rotate = nn.CrossEntropyLoss()
 
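The training loop below unpacks each batch as `noisy_imgs, targets, tasks` and moves `targets` to the device as a single tensor. That only works if the targets in a batch stack into one tensor, yet denoise targets are full images while rotate targets are scalar class indices, so PyTorch's default collation of mixed-task batches would fail. A hedged sketch of the kind of `collate_fn` this implies (hypothetical; the commit does not show `AIIADataLoader`'s actual collate):

```python
import torch

# Hypothetical collate matching the loop's `noisy_imgs, targets, tasks = batch`
# unpacking. It assumes all samples in a batch share one task, since image-shaped
# denoise targets and scalar rotate targets cannot be stacked together.
def pretraining_collate(batch):
    inputs, targets, tasks = zip(*batch)
    inputs = torch.stack(inputs)  # [B, 3, 224, 224]
    targets = torch.stack([torch.as_tensor(t) for t in targets])
    return inputs, targets, list(tasks)
```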
@@ -46,74 +35,61 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
         print(f"\nEpoch {epoch+1}/{num_epochs}")
         print("-" * 20)
 
-        # Training phase
         model.train()
         total_train_loss = 0.0
+        denoise_losses = []
+        rotate_losses = []
 
         for batch in train_loader:
-            images, targets, tasks = zip(*batch)
-
-            if device == "cuda":
-                images = [img.cuda() for img in images]
-                targets = [t.cuda() for t in targets]
-
+            noisy_imgs, targets, tasks = batch
+
+            noisy_imgs = noisy_imgs.to(device)
+            targets = targets.to(device)
             optimizer.zero_grad()
 
-            # Process each sample individually since tasks can vary
-            outputs = []
-            total_loss = 0.0
-            for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
-                output = model(image.unsqueeze(0))
-
+            outputs = model(noisy_imgs)
+            task_losses = []
+            for i, task in enumerate(tasks):
                 if task == 'denoise':
-                    loss = criterion_denoise(output.squeeze(), target)
-                elif task == 'rotate':
-                    loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
-
-                total_loss += loss
-                outputs.append(output)
-
-            avg_loss = total_loss / len(images)
-            avg_loss.backward()
+                    loss = criterion_denoise(outputs[i], targets[i])
+                    denoise_losses.append(loss.item())
+                else:
+                    loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
+                    rotate_losses.append(loss.item())
+                task_losses.append(loss)
+
+            batch_loss = sum(task_losses) / len(task_losses)
+            batch_loss.backward()
             optimizer.step()
 
-            total_train_loss += avg_loss.item()
-            # Separate losses for reporting (you'd need to track this based on tasks)
+            total_train_loss += batch_loss.item()
 
         avg_total_train_loss = total_train_loss / len(train_loader)
         print(f"Training Loss: {avg_total_train_loss:.4f}")
 
-        # Validation phase
         model.eval()
         with torch.no_grad():
             val_losses = []
             for batch in val_loader:
-                images, targets, tasks = zip(*batch)
-
-                if device == "cuda":
-                    images = [img.cuda() for img in images]
-                    targets = [t.cuda() for t in targets]
-
-                outputs = []
-                total_loss = 0.0
-                for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
-                    output = model(image.unsqueeze(0))
-
+                noisy_imgs, targets, tasks = batch
+
+                noisy_imgs = noisy_imgs.to(device)
+                targets = targets.to(device)
+
+                outputs = model(noisy_imgs)
+                task_losses = []
+                for i, task in enumerate(tasks):
                     if task == 'denoise':
-                        loss = criterion_denoise(output.squeeze(), target)
-                    elif task == 'rotate':
-                        loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
-
-                    total_loss += loss
-                    outputs.append(output)
-
-                avg_val_loss = total_loss / len(images)
-                val_losses.append(avg_val_loss.item())
+                        loss = criterion_denoise(outputs[i], targets[i])
+                    else:
+                        loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
+                    task_losses.append(loss)
+
+                batch_loss = sum(task_losses) / len(task_losses)
+                val_losses.append(batch_loss.item())
 
             avg_val_loss = sum(val_losses) / len(val_loader)
             print(f"Validation Loss: {avg_val_loss:.4f}")
 
-        # Save the best model
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
             model.save("BASEv0.1")
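The rewrite replaces per-sample forward/backward passes with one forward pass over the stacked batch and a single backward on the mean of the per-sample losses. Summing loss tensors keeps them in the autograd graph, so one `backward()` call distributes gradients across every sample. A tiny self-contained demonstration of that pattern:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
# Three per-sample losses, kept as tensors so the graph is preserved.
losses = [(w * x - 1.0) ** 2 for x in (0.5, 1.0, 2.0)]
batch_loss = sum(losses) / len(losses)  # same reduction as in the diff
batch_loss.backward()                   # one backward pass for all samples
print(w.grad)                           # gradient of the mean loss, tensor(4.6667)
```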
@@ -122,4 +98,4 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
 if __name__ == "__main__":
     data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
     data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
-    pretrain_model(data_path1, data_path2, num_epochs=8)
+    pretrain_model(data_path1, data_path2, num_epochs=3)