Merge pull request 'improved_shared_cnn' (#3) from improved_shared_cnn into develop
Reviewed-on: Fabel/AIIA#3
This commit is contained in: commit 1e79a93a5e

src/aiia/__init__.py
@@ -1,4 +1,3 @@
-# Import submodules
-from .model import AIIA, AIIAEncoder
+from .model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIArecursive, AIIABaseShared
 from .data import AIIADataLoader
 from .model.config import AIIAConfig
src/aiia/data/DataLoader.py
@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
 from torchvision import transforms
 import random
 import re
+import base64

 class FilePathLoader:
     def __init__(self, dataset, file_path_column="file_path", label_column=None):
@@ -21,14 +21,20 @@ class FilePathLoader:
     def _get_image(self, item):
         try:
             path = item[self.file_path_column]
-            image = Image.open(path).convert("RGB")
+            image = Image.open(path)
+            if image.mode == 'RGBA':
+                background = Image.new('RGB', image.size, (0, 0, 0))
+                background.paste(image, mask=image.split()[3])
+                image = background
+            elif image.mode != 'RGB':
+                image = image.convert('RGB')
             return image
         except Exception as e:
             print(f"Error loading image from {path}: {e}")
             return None

     def get_item(self, idx):
-        item = self.dataset[idx]
+        item = self.dataset.iloc[idx]
         image = self._get_image(item)
         if image is not None:
             self.successful_count += 1
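Both loaders now flatten the alpha channel explicitly instead of calling `convert("RGB")` unconditionally, so transparent pixels composite onto black rather than whatever Pillow's converter picks. A minimal standalone sketch of the same conversion, assuming Pillow is installed (the file name is a made-up placeholder):

    from PIL import Image

    def to_rgb(image: Image.Image) -> Image.Image:
        # Composite RGBA images onto a black background so transparency
        # is flattened deterministically before the alpha channel is dropped.
        if image.mode == 'RGBA':
            background = Image.new('RGB', image.size, (0, 0, 0))
            background.paste(image, mask=image.split()[3])  # band 3 is alpha
            return background
        # Palette, grayscale, CMYK, etc. can be converted directly.
        if image.mode != 'RGB':
            return image.convert('RGB')
        return image

    # Hypothetical usage:
    # print(to_rgb(Image.open("img.png")).mode)  # -> 'RGB'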
@@ -53,21 +59,36 @@ class JPGImageLoader:
         self.successful_count = 0
         self.skipped_count = 0

-        if self.bytes_column not in dataset.column_names:
+        if self.bytes_column not in dataset.columns:
             raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")

     def _get_image(self, item):
         try:
-            bytes_data = item[self.bytes_column]
+            data = item[self.bytes_column]
+
+            if isinstance(data, str) and data.startswith("b'"):
+                cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
+                bytes_data = cleaned_data
+            elif isinstance(data, str):
+                bytes_data = base64.b64decode(data)
+            else:
+                bytes_data = data
+
             img_bytes = io.BytesIO(bytes_data)
-            image = Image.open(img_bytes).convert("RGB")
+            image = Image.open(img_bytes)
+            if image.mode == 'RGBA':
+                background = Image.new('RGB', image.size, (0, 0, 0))
+                background.paste(image, mask=image.split()[3])
+                image = background
+            elif image.mode != 'RGB':
+                image = image.convert('RGB')
             return image
         except Exception as e:
             print(f"Error loading image from bytes: {e}")
             return None

     def get_item(self, idx):
-        item = self.dataset[idx]
+        item = self.dataset.iloc[idx]
         image = self._get_image(item)
         if image is not None:
             self.successful_count += 1
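The reworked `_get_image` accepts three storage formats for the bytes column: a stringified Python bytes repr (`"b'...'"`, as parquet sometimes ends up storing), a base64 string, or raw bytes. A rough sketch of the same dispatch outside the class, with a tiny generated image standing in for real data:

    import base64
    import io
    from PIL import Image

    def decode_image_bytes(data):
        # Stringified bytes repr, e.g. str(b'\xff\xd8...') written into a text column.
        if isinstance(data, str) and data.startswith("b'"):
            bytes_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
        # Any other string: assume base64-encoded image bytes.
        elif isinstance(data, str):
            bytes_data = base64.b64decode(data)
        # Already raw bytes (or bytearray/memoryview).
        else:
            bytes_data = data
        return Image.open(io.BytesIO(bytes_data))

    # Hypothetical round-trip check:
    buf = io.BytesIO()
    Image.new('RGB', (4, 4)).save(buf, format='PNG')
    raw = buf.getvalue()
    assert decode_image_bytes(raw).size == (4, 4)
    assert decode_image_bytes(base64.b64encode(raw).decode('ascii')).size == (4, 4)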
@@ -83,124 +104,125 @@ class JPGImageLoader:
     def print_summary(self):
         print(f"Successfully converted {self.successful_count} images.")
         print(f"Skipped {self.skipped_count} images due to errors.")


-class AIIADataLoader(DataLoader):
-    def __init__(self, dataset,
-                 batch_size=32,
-                 val_split=0.2,
-                 seed=42,
-                 column="file_path",
-                 label_column=None):
-        super().__init__()
+class AIIADataLoader:
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):

         self.batch_size = batch_size
         self.val_split = val_split
         self.seed = seed
+        self.pretraining = pretraining
+        random.seed(seed)

-        # Determine which loader to use based on the dataset's content
-        # Check if any entry in bytes_column is a bytes or bytestring type
-        is_bytes_or_bytestring = any(
-            isinstance(value, (bytes, memoryview))
-            for value in dataset[column].dropna().head(1).astype(str)
+        sample_value = dataset[column].iloc[0]
+        is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
+            isinstance(sample_value, bytes) or
+            sample_value.startswith("b'") or
+            sample_value.startswith(('b"', 'data:image'))
         )

         if is_bytes_or_bytestring:
-            self.loader = JPGImageLoader(
-                dataset,
-                bytes_column=column,
-                label_column=label_column
-            )
+            self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
         else:
-            # Check if file_path column contains valid image file paths (at least one entry)
             sample_paths = dataset[column].dropna().head(1).astype(str)
+            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'

-            # Regex pattern for matching image file paths (adjust as needed)
-            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|png|gif)$'
-
-            if any(
-                re.match(filepath_pattern, path, flags=re.IGNORECASE)
-                for path in sample_paths
-            ):
-                self.loader = FilePathLoader(
-                    dataset,
-                    file_path_column=column,
-                    label_column=label_column
-                )
+            if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
+                self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
             else:
-                # If neither condition is met, default to JPGImageLoader (assuming bytes are stored as strings)
-                self.loader = JPGImageLoader(
-                    dataset,
-                    bytes_column=column,
-                    label_column=label_column
-                )
+                self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)

-        # Get all items
-        self.items = [self.loader.get_item(idx) for idx in range(len(dataset))]
+        self.items = []
+        for idx in range(len(dataset)):
+            item = self.loader.get_item(idx)
+            if item is not None:  # Only add valid items
+                if self.pretraining:
+                    img = item[0] if isinstance(item, tuple) else item
+                    self.items.append((img, 'denoise', img))
+                    self.items.append((img, 'rotate', 0))
+                else:
+                    self.items.append(item)

-        # Split into train and validation sets
+        if not self.items:
+            raise ValueError("No valid items were loaded from the dataset")

         train_indices, val_indices = self._split_data()

-        # Create datasets for training and validation
         self.train_dataset = self._create_subset(train_indices)
         self.val_dataset = self._create_subset(val_indices)

+        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
+        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)

     def _split_data(self):
         if len(self.items) == 0:
-            return [], []
+            raise ValueError("No items to split")

-        tasks = [item[1] if len(item) > 1 and hasattr(item, '__getitem__') else None for item in self.items]
-        unique_tasks = list(set(tasks)) if tasks.count(None) < len(tasks) else []
-
-        train_indices = []
-        val_indices = []
-
-        for task in unique_tasks:
-            task_indices = [i for i, t in enumerate(tasks) if t == task]
-            n_val = int(len(task_indices) * self.val_split)
-
-            random.shuffle(task_indices)
-
-            val_indices.extend(task_indices[:n_val])
-            train_indices.extend(task_indices[n_val:])
+        num_samples = len(self.items)
+        indices = list(range(num_samples))
+        random.shuffle(indices)
+
+        split_idx = int((1 - self.val_split) * num_samples)
+        train_indices = indices[:split_idx]
+        val_indices = indices[split_idx:]

         return train_indices, val_indices

     def _create_subset(self, indices):
         subset_items = [self.items[i] for i in indices]
-        return AIIADataset(subset_items)
+        return AIIADataset(subset_items, pretraining=self.pretraining)


 class AIIADataset(torch.utils.data.Dataset):
-    def __init__(self, items):
+    def __init__(self, items, pretraining=False):
         self.items = items
+        self.pretraining = pretraining
+        self.transform = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor()
+        ])

     def __len__(self):
         return len(self.items)

     def __getitem__(self, idx):
         item = self.items[idx]
-        if isinstance(item, tuple) and len(item) == 2:
-            image, label = item
-            return (image, label)
-        elif isinstance(item, tuple) and len(item) == 3:
+        if self.pretraining:
             image, task, label = item
-            # Handle tasks accordingly (e.g., apply different augmentations)
+            if not isinstance(image, Image.Image):
+                raise ValueError(f"Invalid image at index {idx}")
+
+            image = self.transform(image)
+            if image.shape != (3, 224, 224):
+                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+
             if task == 'denoise':
                 noise_std = 0.1
                 noisy_img = image + torch.randn_like(image) * noise_std
-                target = image
-                return (noisy_img, target, task)
+                target = image.clone()
+                return noisy_img, target, task
             elif task == 'rotate':
                 angles = [0, 90, 180, 270]
                 angle = random.choice(angles)
                 rotated_img = transforms.functional.rotate(image, angle)
-                target = torch.tensor(angle).long()
-                return (rotated_img, target, task)
+                target = torch.tensor(angle / 90).long()
+                return rotated_img, target, task
             else:
-                raise ValueError(f"Unknown task: {task}")
+                raise ValueError(f"Invalid task at index {idx}: {task}")
         else:
-            # Handle single images without labels or tasks
-            if isinstance(item, Image.Image):
-                return item
+            if isinstance(item, tuple) and len(item) == 2:
+                image, label = item
+                if not isinstance(image, Image.Image):
+                    raise ValueError(f"Invalid image at index {idx}")
+                image = self.transform(image)
+                if image.shape != (3, 224, 224):
+                    raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+                return image, label
             else:
-                raise ValueError("Invalid item format.")
+                if isinstance(item, Image.Image):
+                    image = self.transform(item)
+                else:
+                    image = self.transform(item[0])
+                if image.shape != (3, 224, 224):
+                    raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+                return image
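`AIIADataLoader` is no longer a `DataLoader` subclass: it wraps a pandas DataFrame, picks `FilePathLoader` or `JPGImageLoader` from a sample value, splits the items, and exposes ready-made `train_loader`/`val_loader` attributes. A hypothetical construction sketch (the DataFrame, paths, and `list_collate` are made up; mixed denoise/rotate batches need a custom `collate_fn`, as `src/pretrain.py` does with `safe_collate`):

    import pandas as pd
    from aiia.data.DataLoader import AIIADataLoader

    df = pd.DataFrame({"file_path": ["/data/img_000.jpg", "/data/img_001.jpg"]})  # substitute real images

    def list_collate(batch):
        # Keep samples as a plain list; default collation cannot stack
        # mixed denoise (image target) and rotate (class target) samples.
        return batch

    loader = AIIADataLoader(
        df,
        batch_size=2,
        val_split=0.2,
        column="file_path",
        pretraining=True,         # each image yields one 'denoise' and one 'rotate' item
        collate_fn=list_collate,  # extra kwargs are forwarded to torch.utils.data.DataLoader
    )
    print(len(loader.train_dataset), len(loader.val_dataset))
    first_batch = next(iter(loader.train_loader))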
src/aiia/model/Model.py
@@ -1,8 +1,9 @@
-from config import AIIAConfig
+from .config import AIIAConfig
 from torch import nn
 import torch
 import os
-import copy  # Add this for deep copying
+import copy
+

 class AIIA(nn.Module):
     def __init__(self, config: AIIAConfig, **kwargs):
@@ -79,8 +80,9 @@ class AIIABaseShared(AIIA):

         # Initialize max pooling layer
         self.max_pool = nn.MaxPool2d(
-            kernel_size=self.config.kernel_size,
-            padding=1  # Using same padding as in Conv2d layers
+            kernel_size=1,
+            stride=1,
+            padding=1
         )

     def forward(self, x):
@@ -116,7 +118,7 @@ class AIIABase(AIIA):
             nn.Conv2d(in_channels, self.config.hidden_size,
                       kernel_size=self.config.kernel_size, padding=1),
             getattr(nn, self.config.activation_function)(),
-            nn.MaxPool2d(kernel_size=2)
+            nn.MaxPool2d(kernel_size=1, stride=1)
         ])
         in_channels = self.config.hidden_size

@@ -221,9 +223,4 @@ class AIIArecursive(AIIA):
             processed_patches.append(pp)

         combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
         return combined_output
-
-config = AIIAConfig()
-model2 = AIIABaseShared(config)
-
-model2.save("shared")
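Switching `AIIABase` to `nn.MaxPool2d(kernel_size=1, stride=1)` makes the pooling step spatially a no-op, so feature maps keep their input resolution instead of halving at every layer (which the 1x1 `ProjectionHead` in `src/pretrain.py` relies on for the denoise target). A quick self-contained shape check, independent of the AIIA classes:

    import torch
    from torch import nn

    x = torch.randn(1, 3, 224, 224)

    old_pool = nn.MaxPool2d(kernel_size=2)
    new_pool = nn.MaxPool2d(kernel_size=1, stride=1)

    print(old_pool(x).shape)  # torch.Size([1, 3, 112, 112]) -- halves H and W
    print(new_pool(x).shape)  # torch.Size([1, 3, 224, 224]) -- identity on H and W

Note that the `AIIABaseShared` hunk pairs `kernel_size=1` with `padding=1`; PyTorch's pooling requires padding to be at most half the kernel size, so that layer is worth re-checking at runtime.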
src/aiia/model/__init__.py
@@ -1,2 +1,2 @@
 from .config import AIIAConfig
-from .Model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIAresursive
+from .Model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIArecursive, AIIABaseShared
src/aiia/model/config.py
@@ -8,7 +8,7 @@ class AIIAConfig:
     def __init__(
         self,
         model_name: str = "AIIA",
-        kernel_size: int = 5,
+        kernel_size: int = 3,
         activation_function: str = "GELU",
         hidden_size: int = 512,
         num_hidden_layers: int = 12,
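With the new default, a bare `AIIAConfig()` describes 3x3 convolutions, which is what the trimmed-down config in `src/pretrain.py` now relies on. A minimal sketch of how the defaults flow into a model (field names taken from the hunk above; `learning_rate` is assumed to be another config field since the training script reads `config.learning_rate`):

    from aiia.model.config import AIIAConfig
    from aiia.model import AIIABase

    config = AIIAConfig()              # kernel_size now defaults to 3 instead of 5
    model = AIIABase(config)           # conv layers read kernel_size/hidden_size from the config

    print(config.kernel_size)          # 3
    print(config.activation_function)  # 'GELU'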
src/pretrain.py (295 lines changed)
@@ -1,149 +1,226 @@
 import torch
 from torch import nn
-from torch.utils.data import Dataset, DataLoader
-from torchvision import transforms
-from PIL import Image
-import os
-import random
+import csv
 import pandas as pd
 from aiia.model.config import AIIAConfig
 from aiia.model import AIIABase
 from aiia.data.DataLoader import AIIADataLoader
+from tqdm import tqdm

+
+class ProjectionHead(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
+        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees
+
+    def forward(self, x, task='denoise'):
+        if task == 'denoise':
+            return self.conv_denoise(x)
+        else:
+            return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task
+

 def pretrain_model(data_path1, data_path2, num_epochs=3):
-    # Merge the two parquet files
-    df1 = pd.read_parquet(data_path1)
-    df2 = pd.read_parquet(data_path2)
+    # Read and merge datasets
+    df1 = pd.read_parquet(data_path1).head(10000)
+    df2 = pd.read_parquet(data_path2).head(10000)
     merged_df = pd.concat([df1, df2], ignore_index=True)

-    # Create a new AIIAConfig instance
+    # Model configuration
     config = AIIAConfig(
-        model_name="AIIA-512x",
-        hidden_size=512,
-        num_hidden_layers=12,
-        kernel_size=5,
-        learning_rate=5e-5
+        model_name="AIIA-Base-512x20k",
     )

-    # Initialize the base model
+    # Initialize model and projection head
     model = AIIABase(config)
+    projection_head = ProjectionHead()

-    # Create dataset loader with merged data
-    train_dataset = AIIADataLoader(
-        merged_df,
-        batch_size=32,
-        val_split=0.2,
-        seed=42,
-        column="file_path",
-        label_column=None
-    )
-
-    # Create separate dataloaders for training and validation sets
-    train_dataloader = DataLoader(
-        train_dataset.train_dataset,
-        batch_size=train_dataset.batch_size,
-        shuffle=True,
-        num_workers=4
-    )
-
-    val_dataloader = DataLoader(
-        train_dataset.val_ataset,
-        batch_size=train_dataset.batch_size,
-        shuffle=False,
-        num_workers=4
-    )
-
-    # Initialize loss functions and optimizer
-    criterion_denoise = nn.MSELoss()
-    criterion_rotate = nn.CrossEntropyLoss()
-
-    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
-
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model.to(device)
+    projection_head.to(device)
+
+    def safe_collate(batch):
+        denoise_batch = []
+        rotate_batch = []
+
+        for sample in batch:
+            try:
+                noisy_img, target, task = sample
+                if task == 'denoise':
+                    denoise_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
+                else:  # rotate task
+                    rotate_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
+            except Exception as e:
+                print(f"Skipping sample due to error: {e}")
+                continue
+
+        if not denoise_batch and not rotate_batch:
+            return None
+
+        batch_data = {
+            'denoise': None,
+            'rotate': None
+        }
+
+        if denoise_batch:
+            images = torch.stack([x['image'] for x in denoise_batch])
+            targets = torch.stack([x['target'] for x in denoise_batch])
+            batch_data['denoise'] = (images, targets)
+
+        if rotate_batch:
+            images = torch.stack([x['image'] for x in rotate_batch])
+            targets = torch.stack([x['target'] for x in rotate_batch])
+            batch_data['rotate'] = (images, targets)
+
+        return batch_data
+
+    aiia_loader = AIIADataLoader(
+        merged_df,
+        column="image_bytes",
+        batch_size=2,
+        pretraining=True,
+        collate_fn=safe_collate
+    )
+
+    train_loader = aiia_loader.train_loader
+    val_loader = aiia_loader.val_loader
+
+    criterion_denoise = nn.MSELoss()
+    criterion_rotate = nn.CrossEntropyLoss()
+
+    # Update optimizer to include projection head parameters
+    optimizer = torch.optim.AdamW(
+        list(model.parameters()) + list(projection_head.parameters()),
+        lr=config.learning_rate
+    )

     best_val_loss = float('inf')
+    train_losses = []
+    val_losses = []
     for epoch in range(num_epochs):
         print(f"\nEpoch {epoch+1}/{num_epochs}")
         print("-" * 20)

         # Training phase
         model.train()
+        projection_head.train()
         total_train_loss = 0.0
-        denoise_train_loss = 0.0
-        rotate_train_loss = 0.0
+        batch_count = 0

-        for batch in train_dataloader:
-            images, targets, tasks = zip(*batch)
-
-            if device == "cuda":
-                images = [img.cuda() for img in images]
-                targets = [t.cuda() for t in targets]
+        for batch_data in tqdm(train_loader):
+            if batch_data is None:
+                continue

             optimizer.zero_grad()
+            batch_loss = 0

-            # Process each sample individually since tasks can vary
-            outputs = []
-            total_loss = 0.0
-            for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
-                output = model(image.unsqueeze(0))
-
-                if task == 'denoise':
-                    loss = criterion_denoise(output.squeeze(), target)
-                elif task == 'rotate':
-                    loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
-
-                total_loss += loss
-                outputs.append(output)
-
-            avg_loss = total_loss / len(images)
-            avg_loss.backward()
-            optimizer.step()
-
-            total_train_loss += avg_loss.item()
-            # Separate losses for reporting (you'd need to track this based on tasks)
-
-        avg_total_train_loss = total_train_loss / len(train_dataloader)
-        print(f"Training Loss: {avg_total_train_loss:.4f}")
+            # Handle denoise task
+            if batch_data['denoise'] is not None:
+                noisy_imgs, targets = batch_data['denoise']
+                noisy_imgs = noisy_imgs.to(device)
+                targets = targets.to(device)
+
+                # Get features from base model
+                features = model(noisy_imgs)
+                # Project features back to image space
+                outputs = projection_head(features, task='denoise')
+                loss = criterion_denoise(outputs, targets)
+                batch_loss += loss
+
+            # Handle rotate task
+            if batch_data['rotate'] is not None:
+                imgs, targets = batch_data['rotate']
+                imgs = imgs.to(device)
+                targets = targets.long().to(device)
+
+                # Get features from base model
+                features = model(imgs)
+                # Project features to rotation predictions
+                outputs = projection_head(features, task='rotate')
+                loss = criterion_rotate(outputs, targets)
+                batch_loss += loss
+
+            if batch_loss > 0:
+                batch_loss.backward()
+                optimizer.step()
+                total_train_loss += batch_loss.item()
+                batch_count += 1
+
+        avg_train_loss = total_train_loss / max(batch_count, 1)
+        train_losses.append(avg_train_loss)
+        print(f"Training Loss: {avg_train_loss:.4f}")

         # Validation phase
         model.eval()
+        projection_head.eval()
+        val_loss = 0.0
+        val_batch_count = 0

         with torch.no_grad():
-            val_losses = []
-            for batch in val_dataloader:
-                images, targets, tasks = zip(*batch)
-
-                if device == "cuda":
-                    images = [img.cuda() for img in images]
-                    targets = [t.cuda() for t in targets]
-
-                outputs = []
-                total_loss = 0.0
-                for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
-                    output = model(image.unsqueeze(0))
-
-                    if task == 'denoise':
-                        loss = criterion_denoise(output.squeeze(), target)
-                    elif task == 'rotate':
-                        loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
-
-                    total_loss += loss
-                    outputs.append(output)
-
-                avg_val_loss = total_loss / len(images)
-                val_losses.append(avg_val_loss.item())
-
-            avg_val_loss = sum(val_losses) / len(val_dataloader)
-            print(f"Validation Loss: {avg_val_loss:.4f}")
-
-            # Save the best model
+            for batch_data in val_loader:
+                if batch_data is None:
+                    continue
+
+                batch_loss = 0
+
+                if batch_data['denoise'] is not None:
+                    noisy_imgs, targets = batch_data['denoise']
+                    noisy_imgs = noisy_imgs.to(device)
+                    targets = targets.to(device)
+
+                    features = model(noisy_imgs)
+                    outputs = projection_head(features, task='denoise')
+                    loss = criterion_denoise(outputs, targets)
+                    batch_loss += loss
+
+                if batch_data['rotate'] is not None:
+                    imgs, targets = batch_data['rotate']
+                    imgs = imgs.to(device)
+                    targets = targets.long().to(device)
+
+                    features = model(imgs)
+                    outputs = projection_head(features, task='rotate')
+                    loss = criterion_rotate(outputs, targets)
+                    batch_loss += loss
+
+                if batch_loss > 0:
+                    val_loss += batch_loss.item()
+                    val_batch_count += 1
+
+        avg_val_loss = val_loss / max(val_batch_count, 1)
+        val_losses.append(avg_val_loss)
+        print(f"Validation Loss: {avg_val_loss:.4f}")

         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
-            model.save("BASEv0.1")
+            # Save both model and projection head
+            model.save("AIIA-base-512")
             print("Best model saved!")

+    # Prepare the data to be written to the CSV file
+    data = list(zip(range(1, len(train_losses) + 1), train_losses, val_losses))
+
+    # Specify the CSV file name
+    csv_file = 'losses.csv'
+
+    # Write the data to the CSV file
+    with open(csv_file, mode='w', newline='') as file:
+        writer = csv.writer(file)
+        # Write the header
+        writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
+        # Write the data
+        writer.writerows(data)
+    print(f"Data has been written to {csv_file}")
+

 if __name__ == "__main__":
-    data_path1 = "/root/training_data/vision-dataset/images_dataset.parquet"
+    data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
     data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
-    pretrain_model(data_path1, data_path2, num_epochs=8)
+    pretrain_model(data_path1, data_path2, num_epochs=10)
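The new `ProjectionHead` maps the backbone's 512-channel feature map either back to 3-channel image space (denoise) or to 4 rotation logits via global average pooling. A standalone shape check, assuming the backbone preserves spatial resolution as this PR's `AIIABase` now does:

    import torch
    from torch import nn

    class ProjectionHead(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
            self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)

        def forward(self, x, task='denoise'):
            if task == 'denoise':
                return self.conv_denoise(x)
            return self.conv_rotate(x).mean(dim=(2, 3))  # global average pooling

    features = torch.randn(2, 512, 224, 224)  # (batch, channels, H, W) from the backbone
    head = ProjectionHead()
    print(head(features, task='denoise').shape)  # torch.Size([2, 3, 224, 224])
    print(head(features, task='rotate').shape)   # torch.Size([2, 4])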