Merge pull request 'improved_shared_cnn' (#3) from improved_shared_cnn into develop

Reviewed-on: Fabel/AIIA#3
This commit is contained in:
Falko Victor Habel 2025-01-28 10:18:45 +00:00
commit 1e79a93a5e
6 changed files with 301 additions and 206 deletions

View File

@ -1,4 +1,3 @@
# Import submodules
from .model import AIIA, AIIAEncoder
from .model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIArecursive, AIIABaseShared
from .data import AIIADataLoader
from .model.config import AIIAConfig

View File

@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
from torchvision import transforms
import random
import re
import base64
class FilePathLoader:
def __init__(self, dataset, file_path_column="file_path", label_column=None):
@ -21,14 +21,20 @@ class FilePathLoader:
def _get_image(self, item):
try:
path = item[self.file_path_column]
image = Image.open(path).convert("RGB")
image = Image.open(path)
if image.mode == 'RGBA':
background = Image.new('RGB', image.size, (0, 0, 0))
background.paste(image, mask=image.split()[3])
image = background
elif image.mode != 'RGB':
image = image.convert('RGB')
return image
except Exception as e:
print(f"Error loading image from {path}: {e}")
return None
def get_item(self, idx):
item = self.dataset[idx]
item = self.dataset.iloc[idx]
image = self._get_image(item)
if image is not None:
self.successful_count += 1
@ -53,21 +59,36 @@ class JPGImageLoader:
self.successful_count = 0
self.skipped_count = 0
if self.bytes_column not in dataset.column_names:
if self.bytes_column not in dataset.columns:
raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")
def _get_image(self, item):
try:
bytes_data = item[self.bytes_column]
data = item[self.bytes_column]
if isinstance(data, str) and data.startswith("b'"):
cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
bytes_data = cleaned_data
elif isinstance(data, str):
bytes_data = base64.b64decode(data)
else:
bytes_data = data
img_bytes = io.BytesIO(bytes_data)
image = Image.open(img_bytes).convert("RGB")
image = Image.open(img_bytes)
if image.mode == 'RGBA':
background = Image.new('RGB', image.size, (0, 0, 0))
background.paste(image, mask=image.split()[3])
image = background
elif image.mode != 'RGB':
image = image.convert('RGB')
return image
except Exception as e:
print(f"Error loading image from bytes: {e}")
return None
def get_item(self, idx):
item = self.dataset[idx]
item = self.dataset.iloc[idx]
image = self._get_image(item)
if image is not None:
self.successful_count += 1
@ -84,123 +105,124 @@ class JPGImageLoader:
print(f"Successfully converted {self.successful_count} images.")
print(f"Skipped {self.skipped_count} images due to errors.")
class AIIADataLoader(DataLoader):
def __init__(self, dataset,
batch_size=32,
val_split=0.2,
seed=42,
column="file_path",
label_column=None):
super().__init__()
class AIIADataLoader:
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
self.batch_size = batch_size
self.val_split = val_split
self.seed = seed
self.pretraining = pretraining
random.seed(seed)
# Determine which loader to use based on the dataset's content
# Check if any entry in bytes_column is a bytes or bytestring type
is_bytes_or_bytestring = any(
isinstance(value, (bytes, memoryview))
for value in dataset[column].dropna().head(1).astype(str)
sample_value = dataset[column].iloc[0]
is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
isinstance(sample_value, bytes) or
sample_value.startswith("b'") or
sample_value.startswith(('b"', 'data:image'))
)
if is_bytes_or_bytestring:
self.loader = JPGImageLoader(
dataset,
bytes_column=column,
label_column=label_column
)
self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
else:
# Check if file_path column contains valid image file paths (at least one entry)
sample_paths = dataset[column].dropna().head(1).astype(str)
filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'
# Regex pattern for matching image file paths (adjust as needed)
filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|png|gif)$'
if any(
re.match(filepath_pattern, path, flags=re.IGNORECASE)
for path in sample_paths
):
self.loader = FilePathLoader(
dataset,
file_path_column=column,
label_column=label_column
)
if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
else:
# If neither condition is met, default to JPGImageLoader (assuming bytes are stored as strings)
self.loader = JPGImageLoader(
dataset,
bytes_column=column,
label_column=label_column
)
self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
self.items = []
for idx in range(len(dataset)):
item = self.loader.get_item(idx)
if item is not None: # Only add valid items
if self.pretraining:
img = item[0] if isinstance(item, tuple) else item
self.items.append((img, 'denoise', img))
self.items.append((img, 'rotate', 0))
else:
self.items.append(item)
if not self.items:
raise ValueError("No valid items were loaded from the dataset")
# Get all items
self.items = [self.loader.get_item(idx) for idx in range(len(dataset))]
# Split into train and validation sets
train_indices, val_indices = self._split_data()
# Create datasets for training and validation
self.train_dataset = self._create_subset(train_indices)
self.val_dataset = self._create_subset(val_indices)
self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)
def _split_data(self):
if len(self.items) == 0:
return [], []
raise ValueError("No items to split")
tasks = [item[1] if len(item) > 1 and hasattr(item, '__getitem__') else None for item in self.items]
unique_tasks = list(set(tasks)) if tasks.count(None) < len(tasks) else []
num_samples = len(self.items)
indices = list(range(num_samples))
random.shuffle(indices)
train_indices = []
val_indices = []
for task in unique_tasks:
task_indices = [i for i, t in enumerate(tasks) if t == task]
n_val = int(len(task_indices) * self.val_split)
random.shuffle(task_indices)
val_indices.extend(task_indices[:n_val])
train_indices.extend(task_indices[n_val:])
split_idx = int((1 - self.val_split) * num_samples)
train_indices = indices[:split_idx]
val_indices = indices[split_idx:]
return train_indices, val_indices
def _create_subset(self, indices):
subset_items = [self.items[i] for i in indices]
return AIIADataset(subset_items)
return AIIADataset(subset_items, pretraining=self.pretraining)
class AIIADataset(torch.utils.data.Dataset):
def __init__(self, items):
def __init__(self, items, pretraining=False):
self.items = items
self.pretraining = pretraining
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
item = self.items[idx]
if isinstance(item, tuple) and len(item) == 2:
image, label = item
return (image, label)
elif isinstance(item, tuple) and len(item) == 3:
if self.pretraining:
image, task, label = item
# Handle tasks accordingly (e.g., apply different augmentations)
if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}")
image = self.transform(image)
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
if task == 'denoise':
noise_std = 0.1
noisy_img = image + torch.randn_like(image) * noise_std
target = image
return (noisy_img, target, task)
target = image.clone()
return noisy_img, target, task
elif task == 'rotate':
angles = [0, 90, 180, 270]
angle = random.choice(angles)
rotated_img = transforms.functional.rotate(image, angle)
target = torch.tensor(angle).long()
return (rotated_img, target, task)
target = torch.tensor(angle / 90).long()
return rotated_img, target, task
else:
raise ValueError(f"Unknown task: {task}")
raise ValueError(f"Invalid task at index {idx}: {task}")
else:
if isinstance(item, tuple) and len(item) == 2:
image, label = item
if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}")
image = self.transform(image)
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
return image, label
else:
# Handle single images without labels or tasks
if isinstance(item, Image.Image):
return item
image = self.transform(item)
else:
raise ValueError("Invalid item format.")
image = self.transform(item[0])
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
return image

View File

@ -1,8 +1,9 @@
from config import AIIAConfig
from .config import AIIAConfig
from torch import nn
import torch
import os
import copy # Add this for deep copying
import copy
class AIIA(nn.Module):
def __init__(self, config: AIIAConfig, **kwargs):
@ -79,8 +80,9 @@ class AIIABaseShared(AIIA):
# Initialize max pooling layer
self.max_pool = nn.MaxPool2d(
kernel_size=self.config.kernel_size,
padding=1 # Using same padding as in Conv2d layers
kernel_size=1,
stride=1,
padding=1
)
def forward(self, x):
@ -116,7 +118,7 @@ class AIIABase(AIIA):
nn.Conv2d(in_channels, self.config.hidden_size,
kernel_size=self.config.kernel_size, padding=1),
getattr(nn, self.config.activation_function)(),
nn.MaxPool2d(kernel_size=2)
nn.MaxPool2d(kernel_size=1, stride=1)
])
in_channels = self.config.hidden_size
@ -222,8 +224,3 @@ class AIIArecursive(AIIA):
combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
return combined_output
config = AIIAConfig()
model2 = AIIABaseShared(config)
model2.save("shared")

View File

@ -1,2 +1,2 @@
from .config import AIIAConfig
from .Model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIAresursive
from .Model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIArecursive, AIIABaseShared

View File

@ -8,7 +8,7 @@ class AIIAConfig:
def __init__(
self,
model_name: str = "AIIA",
kernel_size: int = 5,
kernel_size: int = 3,
activation_function: str = "GELU",
hidden_size: int = 512,
num_hidden_layers: int = 12,

View File

@ -1,149 +1,226 @@
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import random
import csv
import pandas as pd
from aiia.model.config import AIIAConfig
from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader
from tqdm import tqdm
class ProjectionHead(nn.Module):
    """Task-specific output heads on top of 512-channel encoder features.

    Two 1x1 convolutions share the same input feature map:
    - 'denoise': projects features back to a 3-channel image.
    - otherwise: produces 4 rotation logits (0/90/180/270 degrees) by
      projecting to 4 channels and global-average-pooling the spatial dims.
    """

    def __init__(self):
        super().__init__()
        # 1x1 convs act as per-pixel linear projections of the feature map.
        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees

    def forward(self, x, task='denoise'):
        """Project *x* for *task*; any task other than 'denoise' is treated as rotation."""
        if task != 'denoise':
            # Global average pooling collapses spatial dims into class logits.
            return self.conv_rotate(x).mean(dim=(2, 3))
        return self.conv_denoise(x)
def pretrain_model(data_path1, data_path2, num_epochs=3):
# Merge the two parquet files
df1 = pd.read_parquet(data_path1)
df2 = pd.read_parquet(data_path2)
# Read and merge datasets
df1 = pd.read_parquet(data_path1).head(10000)
df2 = pd.read_parquet(data_path2).head(10000)
merged_df = pd.concat([df1, df2], ignore_index=True)
# Create a new AIIAConfig instance
# Model configuration
config = AIIAConfig(
model_name="AIIA-512x",
hidden_size=512,
num_hidden_layers=12,
kernel_size=5,
learning_rate=5e-5
model_name="AIIA-Base-512x20k",
)
# Initialize the base model
# Initialize model and projection head
model = AIIABase(config)
# Create dataset loader with merged data
train_dataset = AIIADataLoader(
merged_df,
batch_size=32,
val_split=0.2,
seed=42,
column="file_path",
label_column=None
)
# Create separate dataloaders for training and validation sets
train_dataloader = DataLoader(
train_dataset.train_dataset,
batch_size=train_dataset.batch_size,
shuffle=True,
num_workers=4
)
val_dataloader = DataLoader(
train_dataset.val_ataset,
batch_size=train_dataset.batch_size,
shuffle=False,
num_workers=4
)
# Initialize loss functions and optimizer
criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
projection_head = ProjectionHead()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
projection_head.to(device)
def safe_collate(batch):
    """Collate mixed-task samples into per-task stacked tensors.

    Each sample is expected to be an (image, target, task) triple. Samples
    are bucketed by task ('denoise' vs everything else), and each non-empty
    bucket is stacked into (images, targets) tensor pairs.

    Returns a dict {'denoise': pair-or-None, 'rotate': pair-or-None}, or
    None when no sample in the batch could be processed.
    """
    buckets = {'denoise': [], 'rotate': []}
    for sample in batch:
        try:
            noisy_img, target, task = sample
            key = 'denoise' if task == 'denoise' else 'rotate'
            buckets[key].append({'image': noisy_img, 'target': target, 'task': task})
        except Exception as e:
            # A malformed sample is dropped rather than aborting the whole batch.
            print(f"Skipping sample due to error: {e}")
            continue
    if not buckets['denoise'] and not buckets['rotate']:
        return None
    batch_data = {'denoise': None, 'rotate': None}
    for key in ('denoise', 'rotate'):
        entries = buckets[key]
        if entries:
            images = torch.stack([entry['image'] for entry in entries])
            targets = torch.stack([entry['target'] for entry in entries])
            batch_data[key] = (images, targets)
    return batch_data
aiia_loader = AIIADataLoader(
merged_df,
column="image_bytes",
batch_size=2,
pretraining=True,
collate_fn=safe_collate
)
train_loader = aiia_loader.train_loader
val_loader = aiia_loader.val_loader
criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()
# Update optimizer to include projection head parameters
optimizer = torch.optim.AdamW(
list(model.parameters()) + list(projection_head.parameters()),
lr=config.learning_rate
)
best_val_loss = float('inf')
train_losses = []
val_losses = []
for epoch in range(num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
print("-" * 20)
# Training phase
model.train()
projection_head.train()
total_train_loss = 0.0
denoise_train_loss = 0.0
rotate_train_loss = 0.0
batch_count = 0
for batch in train_dataloader:
images, targets, tasks = zip(*batch)
if device == "cuda":
images = [img.cuda() for img in images]
targets = [t.cuda() for t in targets]
for batch_data in tqdm(train_loader):
if batch_data is None:
continue
optimizer.zero_grad()
batch_loss = 0
# Process each sample individually since tasks can vary
outputs = []
total_loss = 0.0
for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
output = model(image.unsqueeze(0))
# Handle denoise task
if batch_data['denoise'] is not None:
noisy_imgs, targets = batch_data['denoise']
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
if task == 'denoise':
loss = criterion_denoise(output.squeeze(), target)
elif task == 'rotate':
loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
# Get features from base model
features = model(noisy_imgs)
# Project features back to image space
outputs = projection_head(features, task='denoise')
loss = criterion_denoise(outputs, targets)
batch_loss += loss
total_loss += loss
outputs.append(output)
# Handle rotate task
if batch_data['rotate'] is not None:
imgs, targets = batch_data['rotate']
imgs = imgs.to(device)
targets = targets.long().to(device)
avg_loss = total_loss / len(images)
avg_loss.backward()
# Get features from base model
features = model(imgs)
# Project features to rotation predictions
outputs = projection_head(features, task='rotate')
loss = criterion_rotate(outputs, targets)
batch_loss += loss
if batch_loss > 0:
batch_loss.backward()
optimizer.step()
total_train_loss += batch_loss.item()
batch_count += 1
total_train_loss += avg_loss.item()
# Separate losses for reporting (you'd need to track this based on tasks)
avg_total_train_loss = total_train_loss / len(train_dataloader)
print(f"Training Loss: {avg_total_train_loss:.4f}")
avg_train_loss = total_train_loss / max(batch_count, 1)
train_losses.append(avg_train_loss)
print(f"Training Loss: {avg_train_loss:.4f}")
# Validation phase
model.eval()
projection_head.eval()
val_loss = 0.0
val_batch_count = 0
with torch.no_grad():
val_losses = []
for batch in val_dataloader:
images, targets, tasks = zip(*batch)
for batch_data in val_loader:
if batch_data is None:
continue
if device == "cuda":
images = [img.cuda() for img in images]
targets = [t.cuda() for t in targets]
batch_loss = 0
outputs = []
total_loss = 0.0
for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
output = model(image.unsqueeze(0))
if batch_data['denoise'] is not None:
noisy_imgs, targets = batch_data['denoise']
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
if task == 'denoise':
loss = criterion_denoise(output.squeeze(), target)
elif task == 'rotate':
loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
features = model(noisy_imgs)
outputs = projection_head(features, task='denoise')
loss = criterion_denoise(outputs, targets)
batch_loss += loss
total_loss += loss
outputs.append(output)
if batch_data['rotate'] is not None:
imgs, targets = batch_data['rotate']
imgs = imgs.to(device)
targets = targets.long().to(device)
avg_val_loss = total_loss / len(images)
val_losses.append(avg_val_loss.item())
features = model(imgs)
outputs = projection_head(features, task='rotate')
loss = criterion_rotate(outputs, targets)
batch_loss += loss
avg_val_loss = sum(val_losses) / len(val_dataloader)
if batch_loss > 0:
val_loss += batch_loss.item()
val_batch_count += 1
avg_val_loss = val_loss / max(val_batch_count, 1)
val_losses.append(avg_val_loss)
print(f"Validation Loss: {avg_val_loss:.4f}")
# Save the best model
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
model.save("BASEv0.1")
# Save both model and projection head
model.save("AIIA-base-512")
print("Best model saved!")
# Prepare the data to be written to the CSV file
data = list(zip(range(1, len(train_losses) + 1), train_losses, val_losses))
# Specify the CSV file name
csv_file = 'losses.csv'
# Write the data to the CSV file
with open(csv_file, mode='w', newline='') as file:
writer = csv.writer(file)
# Write the header
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
# Write the data
writer.writerows(data)
print(f"Data has been written to {csv_file}")
if __name__ == "__main__":
data_path1 = "/root/training_data/vision-dataset/images_dataset.parquet"
data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
pretrain_model(data_path1, data_path2, num_epochs=8)
pretrain_model(data_path1, data_path2, num_epochs=10)