Merge pull request 'improved_shared_cnn' (#3) from improved_shared_cnn into develop

Reviewed-on: Fabel/AIIA#3
Falko Victor Habel 2025-01-28 10:18:45 +00:00
commit 1e79a93a5e
6 changed files with 301 additions and 206 deletions

aiia/__init__.py View File

@@ -1,4 +1,3 @@
-# Import submodules
-from .model import AIIA, AIIAEncoder
+from .model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIArecursive, AIIABaseShared
 from .data import AIIADataLoader
 from .model.config import AIIAConfig

aiia/data/DataLoader.py View File

@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
 from torchvision import transforms
 import random
 import re
+import base64
 
 class FilePathLoader:
     def __init__(self, dataset, file_path_column="file_path", label_column=None):
@@ -21,14 +21,20 @@ class FilePathLoader:
     def _get_image(self, item):
         try:
             path = item[self.file_path_column]
-            image = Image.open(path).convert("RGB")
+            image = Image.open(path)
+            if image.mode == 'RGBA':
+                background = Image.new('RGB', image.size, (0, 0, 0))
+                background.paste(image, mask=image.split()[3])
+                image = background
+            elif image.mode != 'RGB':
+                image = image.convert('RGB')
             return image
         except Exception as e:
             print(f"Error loading image from {path}: {e}")
             return None
 
     def get_item(self, idx):
-        item = self.dataset[idx]
+        item = self.dataset.iloc[idx]
         image = self._get_image(item)
         if image is not None:
             self.successful_count += 1
@@ -53,21 +59,36 @@ class JPGImageLoader:
         self.successful_count = 0
         self.skipped_count = 0
 
-        if self.bytes_column not in dataset.column_names:
+        if self.bytes_column not in dataset.columns:
             raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")
 
     def _get_image(self, item):
         try:
-            bytes_data = item[self.bytes_column]
+            data = item[self.bytes_column]
+
+            if isinstance(data, str) and data.startswith("b'"):
+                cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
+                bytes_data = cleaned_data
+            elif isinstance(data, str):
+                bytes_data = base64.b64decode(data)
+            else:
+                bytes_data = data
+
             img_bytes = io.BytesIO(bytes_data)
-            image = Image.open(img_bytes).convert("RGB")
+            image = Image.open(img_bytes)
+            if image.mode == 'RGBA':
+                background = Image.new('RGB', image.size, (0, 0, 0))
+                background.paste(image, mask=image.split()[3])
+                image = background
+            elif image.mode != 'RGB':
+                image = image.convert('RGB')
             return image
         except Exception as e:
             print(f"Error loading image from bytes: {e}")
             return None
 
     def get_item(self, idx):
-        item = self.dataset[idx]
+        item = self.dataset.iloc[idx]
         image = self._get_image(item)
         if image is not None:
             self.successful_count += 1
@@ -83,124 +104,125 @@ class JPGImageLoader:
     def print_summary(self):
         print(f"Successfully converted {self.successful_count} images.")
         print(f"Skipped {self.skipped_count} images due to errors.")
 
-class AIIADataLoader(DataLoader):
-    def __init__(self, dataset,
-                 batch_size=32,
-                 val_split=0.2,
-                 seed=42,
-                 column="file_path",
-                 label_column=None):
-        super().__init__()
-
+class AIIADataLoader:
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
         self.batch_size = batch_size
         self.val_split = val_split
         self.seed = seed
+        self.pretraining = pretraining
+        random.seed(seed)
 
-        # Determine which loader to use based on the dataset's content
-        # Check if any entry in bytes_column is a bytes or bytestring type
-        is_bytes_or_bytestring = any(
-            isinstance(value, (bytes, memoryview))
-            for value in dataset[column].dropna().head(1).astype(str)
-        )
+        sample_value = dataset[column].iloc[0]
+        is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
+            isinstance(sample_value, bytes) or
+            sample_value.startswith("b'") or
+            sample_value.startswith(('b"', 'data:image'))
+        )
 
         if is_bytes_or_bytestring:
-            self.loader = JPGImageLoader(
-                dataset,
-                bytes_column=column,
-                label_column=label_column
-            )
+            self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
         else:
-            # Check if file_path column contains valid image file paths (at least one entry)
             sample_paths = dataset[column].dropna().head(1).astype(str)
+            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'
 
-            # Regex pattern for matching image file paths (adjust as needed)
-            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|png|gif)$'
-
-            if any(
-                re.match(filepath_pattern, path, flags=re.IGNORECASE)
-                for path in sample_paths
-            ):
-                self.loader = FilePathLoader(
-                    dataset,
-                    file_path_column=column,
-                    label_column=label_column
-                )
+            if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
+                self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
             else:
-                # If neither condition is met, default to JPGImageLoader (assuming bytes are stored as strings)
-                self.loader = JPGImageLoader(
-                    dataset,
-                    bytes_column=column,
-                    label_column=label_column
-                )
+                self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
 
-        # Get all items
-        self.items = [self.loader.get_item(idx) for idx in range(len(dataset))]
+        self.items = []
+        for idx in range(len(dataset)):
+            item = self.loader.get_item(idx)
+            if item is not None:  # Only add valid items
+                if self.pretraining:
+                    img = item[0] if isinstance(item, tuple) else item
+                    self.items.append((img, 'denoise', img))
+                    self.items.append((img, 'rotate', 0))
+                else:
+                    self.items.append(item)
 
-        # Split into train and validation sets
+        if not self.items:
+            raise ValueError("No valid items were loaded from the dataset")
+
         train_indices, val_indices = self._split_data()
-
-        # Create datasets for training and validation
         self.train_dataset = self._create_subset(train_indices)
         self.val_dataset = self._create_subset(val_indices)
+        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
+        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)
 
     def _split_data(self):
         if len(self.items) == 0:
-            return [], []
+            raise ValueError("No items to split")
 
-        tasks = [item[1] if len(item) > 1 and hasattr(item, '__getitem__') else None for item in self.items]
-        unique_tasks = list(set(tasks)) if tasks.count(None) < len(tasks) else []
+        num_samples = len(self.items)
+        indices = list(range(num_samples))
+        random.shuffle(indices)
 
-        train_indices = []
-        val_indices = []
-
-        for task in unique_tasks:
-            task_indices = [i for i, t in enumerate(tasks) if t == task]
-            n_val = int(len(task_indices) * self.val_split)
-
-            random.shuffle(task_indices)
-
-            val_indices.extend(task_indices[:n_val])
-            train_indices.extend(task_indices[n_val:])
+        split_idx = int((1 - self.val_split) * num_samples)
+        train_indices = indices[:split_idx]
+        val_indices = indices[split_idx:]
 
         return train_indices, val_indices
 
     def _create_subset(self, indices):
         subset_items = [self.items[i] for i in indices]
-        return AIIADataset(subset_items)
+        return AIIADataset(subset_items, pretraining=self.pretraining)
 
 class AIIADataset(torch.utils.data.Dataset):
-    def __init__(self, items):
+    def __init__(self, items, pretraining=False):
         self.items = items
+        self.pretraining = pretraining
+        self.transform = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor()
+        ])
 
     def __len__(self):
         return len(self.items)
 
     def __getitem__(self, idx):
         item = self.items[idx]
-        if isinstance(item, tuple) and len(item) == 2:
-            image, label = item
-            return (image, label)
-        elif isinstance(item, tuple) and len(item) == 3:
+
+        if self.pretraining:
             image, task, label = item
-            # Handle tasks accordingly (e.g., apply different augmentations)
+            if not isinstance(image, Image.Image):
+                raise ValueError(f"Invalid image at index {idx}")
+
+            image = self.transform(image)
+            if image.shape != (3, 224, 224):
+                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+
             if task == 'denoise':
                 noise_std = 0.1
                 noisy_img = image + torch.randn_like(image) * noise_std
-                target = image
-                return (noisy_img, target, task)
+                target = image.clone()
+                return noisy_img, target, task
             elif task == 'rotate':
                 angles = [0, 90, 180, 270]
                 angle = random.choice(angles)
                 rotated_img = transforms.functional.rotate(image, angle)
-                target = torch.tensor(angle).long()
-                return (rotated_img, target, task)
+                target = torch.tensor(angle / 90).long()
+                return rotated_img, target, task
             else:
-                raise ValueError(f"Unknown task: {task}")
+                raise ValueError(f"Invalid task at index {idx}: {task}")
         else:
-            # Handle single images without labels or tasks
-            if isinstance(item, Image.Image):
-                return item
+            if isinstance(item, tuple) and len(item) == 2:
+                image, label = item
+                if not isinstance(image, Image.Image):
+                    raise ValueError(f"Invalid image at index {idx}")
+                image = self.transform(image)
+                if image.shape != (3, 224, 224):
+                    raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+                return image, label
             else:
-                raise ValueError("Invalid item format.")
+                if isinstance(item, Image.Image):
+                    image = self.transform(item)
+                else:
+                    image = self.transform(item[0])
+                if image.shape != (3, 224, 224):
+                    raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+                return image

aiia/model/Model.py View File

@@ -1,8 +1,9 @@
-from config import AIIAConfig
+from .config import AIIAConfig
 from torch import nn
 import torch
 import os
-import copy  # Add this for deep copying
+import copy
 
 class AIIA(nn.Module):
     def __init__(self, config: AIIAConfig, **kwargs):
@ -79,8 +80,9 @@ class AIIABaseShared(AIIA):
# Initialize max pooling layer # Initialize max pooling layer
self.max_pool = nn.MaxPool2d( self.max_pool = nn.MaxPool2d(
kernel_size=self.config.kernel_size, kernel_size=1,
padding=1 # Using same padding as in Conv2d layers stride=1,
padding=1
) )
def forward(self, x): def forward(self, x):
@@ -116,7 +118,7 @@ class AIIABase(AIIA):
             nn.Conv2d(in_channels, self.config.hidden_size,
                       kernel_size=self.config.kernel_size, padding=1),
             getattr(nn, self.config.activation_function)(),
-            nn.MaxPool2d(kernel_size=2)
+            nn.MaxPool2d(kernel_size=1, stride=1)
         ])
 
         in_channels = self.config.hidden_size
@ -221,9 +223,4 @@ class AIIArecursive(AIIA):
processed_patches.append(pp) processed_patches.append(pp)
combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0) combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
return combined_output return combined_output
config = AIIAConfig()
model2 = AIIABaseShared(config)
model2.save("shared")

aiia/model/__init__.py View File

@@ -1,2 +1,2 @@
 from .config import AIIAConfig
-from .Model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIAresursive
+from .Model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIArecursive, AIIABaseShared

aiia/model/config.py View File

@@ -8,7 +8,7 @@ class AIIAConfig:
     def __init__(
         self,
         model_name: str = "AIIA",
-        kernel_size: int = 5,
+        kernel_size: int = 3,
         activation_function: str = "GELU",
         hidden_size: int = 512,
         num_hidden_layers: int = 12,
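The kernel_size default drops from 5 to 3; code that relied on the old receptive field can still request it explicitly. A small sketch:

from aiia.model.config import AIIAConfig

config = AIIAConfig()               # kernel_size now defaults to 3
legacy = AIIAConfig(kernel_size=5)  # the previous default, set explicitly
print(config.kernel_size, legacy.kernel_size)  # 3 5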

View File

@@ -1,149 +1,226 @@
 import torch
 from torch import nn
-from torch.utils.data import Dataset, DataLoader
-from torchvision import transforms
-from PIL import Image
-import os
-import random
+import csv
 import pandas as pd
 from aiia.model.config import AIIAConfig
 from aiia.model import AIIABase
 from aiia.data.DataLoader import AIIADataLoader
+from tqdm import tqdm
+
+class ProjectionHead(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
+        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees
+
+    def forward(self, x, task='denoise'):
+        if task == 'denoise':
+            return self.conv_denoise(x)
+        else:
+            return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task
 
 def pretrain_model(data_path1, data_path2, num_epochs=3):
-    # Merge the two parquet files
-    df1 = pd.read_parquet(data_path1)
-    df2 = pd.read_parquet(data_path2)
+    # Read and merge datasets
+    df1 = pd.read_parquet(data_path1).head(10000)
+    df2 = pd.read_parquet(data_path2).head(10000)
     merged_df = pd.concat([df1, df2], ignore_index=True)
 
-    # Create a new AIIAConfig instance
+    # Model configuration
     config = AIIAConfig(
-        model_name="AIIA-512x",
-        hidden_size=512,
-        num_hidden_layers=12,
-        kernel_size=5,
-        learning_rate=5e-5
+        model_name="AIIA-Base-512x20k",
     )
 
-    # Initialize the base model
+    # Initialize model and projection head
     model = AIIABase(config)
+    projection_head = ProjectionHead()
 
-    # Create dataset loader with merged data
-    train_dataset = AIIADataLoader(
-        merged_df,
-        batch_size=32,
-        val_split=0.2,
-        seed=42,
-        column="file_path",
-        label_column=None
-    )
-
-    # Create separate dataloaders for training and validation sets
-    train_dataloader = DataLoader(
-        train_dataset.train_dataset,
-        batch_size=train_dataset.batch_size,
-        shuffle=True,
-        num_workers=4
-    )
-
-    val_dataloader = DataLoader(
-        train_dataset.val_ataset,
-        batch_size=train_dataset.batch_size,
-        shuffle=False,
-        num_workers=4
-    )
-
-    # Initialize loss functions and optimizer
-    criterion_denoise = nn.MSELoss()
-    criterion_rotate = nn.CrossEntropyLoss()
-    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
-
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model.to(device)
+    projection_head.to(device)
+
+    def safe_collate(batch):
+        denoise_batch = []
+        rotate_batch = []
+
+        for sample in batch:
+            try:
+                noisy_img, target, task = sample
+                if task == 'denoise':
+                    denoise_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
+                else:  # rotate task
+                    rotate_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
+            except Exception as e:
+                print(f"Skipping sample due to error: {e}")
+                continue
+
+        if not denoise_batch and not rotate_batch:
+            return None
+
+        batch_data = {
+            'denoise': None,
+            'rotate': None
+        }
+
+        if denoise_batch:
+            images = torch.stack([x['image'] for x in denoise_batch])
+            targets = torch.stack([x['target'] for x in denoise_batch])
+            batch_data['denoise'] = (images, targets)
+
+        if rotate_batch:
+            images = torch.stack([x['image'] for x in rotate_batch])
+            targets = torch.stack([x['target'] for x in rotate_batch])
+            batch_data['rotate'] = (images, targets)
+
+        return batch_data
+
+    aiia_loader = AIIADataLoader(
+        merged_df,
+        column="image_bytes",
+        batch_size=2,
+        pretraining=True,
+        collate_fn=safe_collate
+    )
+
+    train_loader = aiia_loader.train_loader
+    val_loader = aiia_loader.val_loader
+
+    criterion_denoise = nn.MSELoss()
+    criterion_rotate = nn.CrossEntropyLoss()
+
+    # Update optimizer to include projection head parameters
+    optimizer = torch.optim.AdamW(
+        list(model.parameters()) + list(projection_head.parameters()),
+        lr=config.learning_rate
+    )
 
     best_val_loss = float('inf')
+    train_losses = []
+    val_losses = []
 
     for epoch in range(num_epochs):
         print(f"\nEpoch {epoch+1}/{num_epochs}")
         print("-" * 20)
 
         # Training phase
         model.train()
+        projection_head.train()
         total_train_loss = 0.0
-        denoise_train_loss = 0.0
-        rotate_train_loss = 0.0
+        batch_count = 0
 
-        for batch in train_dataloader:
-            images, targets, tasks = zip(*batch)
-            if device == "cuda":
-                images = [img.cuda() for img in images]
-                targets = [t.cuda() for t in targets]
+        for batch_data in tqdm(train_loader):
+            if batch_data is None:
+                continue
 
             optimizer.zero_grad()
+            batch_loss = 0
 
-            # Process each sample individually since tasks can vary
-            outputs = []
-            total_loss = 0.0
-            for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
-                output = model(image.unsqueeze(0))
+            # Handle denoise task
+            if batch_data['denoise'] is not None:
+                noisy_imgs, targets = batch_data['denoise']
+                noisy_imgs = noisy_imgs.to(device)
+                targets = targets.to(device)
 
-                if task == 'denoise':
-                    loss = criterion_denoise(output.squeeze(), target)
-                elif task == 'rotate':
-                    loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
+                # Get features from base model
+                features = model(noisy_imgs)
+                # Project features back to image space
+                outputs = projection_head(features, task='denoise')
+                loss = criterion_denoise(outputs, targets)
+                batch_loss += loss
 
-                total_loss += loss
-                outputs.append(output)
+            # Handle rotate task
+            if batch_data['rotate'] is not None:
+                imgs, targets = batch_data['rotate']
+                imgs = imgs.to(device)
+                targets = targets.long().to(device)
 
-            avg_loss = total_loss / len(images)
-            avg_loss.backward()
-            optimizer.step()
+                # Get features from base model
+                features = model(imgs)
+                # Project features to rotation predictions
+                outputs = projection_head(features, task='rotate')
 
-            total_train_loss += avg_loss.item()
-            # Separate losses for reporting (you'd need to track this based on tasks)
+                loss = criterion_rotate(outputs, targets)
+                batch_loss += loss
 
-        avg_total_train_loss = total_train_loss / len(train_dataloader)
-        print(f"Training Loss: {avg_total_train_loss:.4f}")
+            if batch_loss > 0:
+                batch_loss.backward()
+                optimizer.step()
+                total_train_loss += batch_loss.item()
+                batch_count += 1
+
+        avg_train_loss = total_train_loss / max(batch_count, 1)
+        train_losses.append(avg_train_loss)
+        print(f"Training Loss: {avg_train_loss:.4f}")
 
         # Validation phase
         model.eval()
+        projection_head.eval()
+        val_loss = 0.0
+        val_batch_count = 0
 
         with torch.no_grad():
-            val_losses = []
-            for batch in val_dataloader:
-                images, targets, tasks = zip(*batch)
-                if device == "cuda":
-                    images = [img.cuda() for img in images]
-                    targets = [t.cuda() for t in targets]
+            for batch_data in val_loader:
+                if batch_data is None:
+                    continue
 
-                outputs = []
-                total_loss = 0.0
-                for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
-                    output = model(image.unsqueeze(0))
+                batch_loss = 0
 
-                    if task == 'denoise':
-                        loss = criterion_denoise(output.squeeze(), target)
-                    elif task == 'rotate':
-                        loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
+                if batch_data['denoise'] is not None:
+                    noisy_imgs, targets = batch_data['denoise']
+                    noisy_imgs = noisy_imgs.to(device)
+                    targets = targets.to(device)
 
-                    total_loss += loss
-                    outputs.append(output)
+                    features = model(noisy_imgs)
+                    outputs = projection_head(features, task='denoise')
+                    loss = criterion_denoise(outputs, targets)
+                    batch_loss += loss
 
-                avg_val_loss = total_loss / len(images)
-                val_losses.append(avg_val_loss.item())
+                if batch_data['rotate'] is not None:
+                    imgs, targets = batch_data['rotate']
+                    imgs = imgs.to(device)
+                    targets = targets.long().to(device)
 
-            avg_val_loss = sum(val_losses) / len(val_dataloader)
-            print(f"Validation Loss: {avg_val_loss:.4f}")
+                    features = model(imgs)
+                    outputs = projection_head(features, task='rotate')
+                    loss = criterion_rotate(outputs, targets)
+                    batch_loss += loss
 
-            # Save the best model
-            if avg_val_loss < best_val_loss:
-                best_val_loss = avg_val_loss
-                model.save("BASEv0.1")
-                print("Best model saved!")
+                if batch_loss > 0:
+                    val_loss += batch_loss.item()
+                    val_batch_count += 1
+
+        avg_val_loss = val_loss / max(val_batch_count, 1)
+        val_losses.append(avg_val_loss)
+        print(f"Validation Loss: {avg_val_loss:.4f}")
+
+        if avg_val_loss < best_val_loss:
+            best_val_loss = avg_val_loss
+            # Save both model and projection head
+            model.save("AIIA-base-512")
+            print("Best model saved!")
+
+    # Prepare the data to be written to the CSV file
+    data = list(zip(range(1, len(train_losses) + 1), train_losses, val_losses))
+
+    # Specify the CSV file name
+    csv_file = 'losses.csv'
+
+    # Write the data to the CSV file
+    with open(csv_file, mode='w', newline='') as file:
+        writer = csv.writer(file)
+        # Write the header
+        writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
+        # Write the data
+        writer.writerows(data)
+    print(f"Data has been written to {csv_file}")
 
 if __name__ == "__main__":
-    data_path1 = "/root/training_data/vision-dataset/images_dataset.parquet"
+    data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
     data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
-    pretrain_model(data_path1, data_path2, num_epochs=8)
+    pretrain_model(data_path1, data_path2, num_epochs=10)
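A standalone shape check for the new ProjectionHead (a minimal sketch; the [batch, 512, H, W] feature shape is an assumption implied by the head's 1x1 convolutions and the config's hidden_size=512):

import torch
from torch import nn

class ProjectionHead(nn.Module):
    # Copied from the script above so the sketch runs on its own.
    def __init__(self):
        super().__init__()
        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        return self.conv_rotate(x).mean(dim=(2, 3))  # global average pooling

head = ProjectionHead()
features = torch.randn(2, 512, 224, 224)  # stand-in for AIIABase(noisy_imgs)
print(head(features, task='denoise').shape)  # torch.Size([2, 3, 224, 224])
print(head(features, task='rotate').shape)   # torch.Size([2, 4])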