diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py
index 25b8128..71acf48 100644
--- a/src/aiia/__init__.py
+++ b/src/aiia/__init__.py
@@ -1,4 +1,3 @@
-# Import submodules
-from .model import AIIA, AIIAEncoder
+from .model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIArecursive, AIIABaseShared
 from .data import AIIADataLoader
 from .model.config import AIIAConfig
\ No newline at end of file
diff --git a/src/aiia/data/DataLoader.py b/src/aiia/data/DataLoader.py
index d3cb900..4ba5032 100644
--- a/src/aiia/data/DataLoader.py
+++ b/src/aiia/data/DataLoader.py
@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
 from torchvision import transforms
 import random
 import re
-
+import base64
 
 class FilePathLoader:
     def __init__(self, dataset, file_path_column="file_path", label_column=None):
@@ -21,14 +21,20 @@ class FilePathLoader:
     def _get_image(self, item):
         try:
             path = item[self.file_path_column]
-            image = Image.open(path).convert("RGB")
+            image = Image.open(path)
+            if image.mode == 'RGBA':
+                background = Image.new('RGB', image.size, (0, 0, 0))
+                background.paste(image, mask=image.split()[3])
+                image = background
+            elif image.mode != 'RGB':
+                image = image.convert('RGB')
             return image
         except Exception as e:
             print(f"Error loading image from {path}: {e}")
             return None
 
     def get_item(self, idx):
-        item = self.dataset[idx]
+        item = self.dataset.iloc[idx]
         image = self._get_image(item)
         if image is not None:
             self.successful_count += 1
@@ -53,21 +59,36 @@ class JPGImageLoader:
         self.successful_count = 0
         self.skipped_count = 0
 
-        if self.bytes_column not in dataset.column_names:
+        if self.bytes_column not in dataset.columns:
             raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")
-
+
     def _get_image(self, item):
         try:
-            bytes_data = item[self.bytes_column]
+            data = item[self.bytes_column]
+
+            if isinstance(data, str) and data.startswith("b'"):
+                cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
+                bytes_data = cleaned_data
+            elif isinstance(data, str):
+                bytes_data = base64.b64decode(data)
+            else:
+                bytes_data = data
+
             img_bytes = io.BytesIO(bytes_data)
-            image = Image.open(img_bytes).convert("RGB")
+            image = Image.open(img_bytes)
+            if image.mode == 'RGBA':
+                background = Image.new('RGB', image.size, (0, 0, 0))
+                background.paste(image, mask=image.split()[3])
+                image = background
+            elif image.mode != 'RGB':
+                image = image.convert('RGB')
             return image
         except Exception as e:
             print(f"Error loading image from bytes: {e}")
             return None
 
     def get_item(self, idx):
-        item = self.dataset[idx]
+        item = self.dataset.iloc[idx]
         image = self._get_image(item)
         if image is not None:
             self.successful_count += 1
@@ -83,124 +104,125 @@ class JPGImageLoader:
     def print_summary(self):
         print(f"Successfully converted {self.successful_count} images.")
         print(f"Skipped {self.skipped_count} images due to errors.")
 
-class AIIADataLoader(DataLoader):
-    def __init__(self, dataset,
-                batch_size=32,
-                val_split=0.2,
-                seed=42,
-                column="file_path",
-                label_column=None):
-        super().__init__()
-
+class AIIADataLoader:
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
         self.batch_size = batch_size
         self.val_split = val_split
         self.seed = seed
+        self.pretraining = pretraining
+        random.seed(seed)
 
-        # Determine which loader to use based on the dataset's content
-        # Check if any entry in bytes_column is a bytes or bytestring type
-        is_bytes_or_bytestring = any(
-            isinstance(value, (bytes, memoryview))
-            for value in dataset[column].dropna().head(1).astype(str)
+        sample_value = dataset[column].iloc[0]
+        is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
+            isinstance(sample_value, bytes) or
+            sample_value.startswith("b'") or
+            sample_value.startswith(('b"', 'data:image'))
         )
 
         if is_bytes_or_bytestring:
-            self.loader = JPGImageLoader(
-                dataset,
-                bytes_column=column,
-                label_column=label_column
-            )
+            self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
         else:
-            # Check if file_path column contains valid image file paths (at least one entry)
             sample_paths = dataset[column].dropna().head(1).astype(str)
+            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'
 
-            # Regex pattern for matching image file paths (adjust as needed)
-            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|png|gif)$'
-
-            if any(
-                re.match(filepath_pattern, path, flags=re.IGNORECASE)
-                for path in sample_paths
-            ):
-                self.loader = FilePathLoader(
-                    dataset,
-                    file_path_column=column,
-                    label_column=label_column
-                )
+            if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
+                self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
             else:
-                # If neither condition is met, default to JPGImageLoader (assuming bytes are stored as strings)
-                self.loader = JPGImageLoader(
-                    dataset,
-                    bytes_column=column,
-                    label_column=label_column
-                )
+                self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
 
-        # Get all items
-        self.items = [self.loader.get_item(idx) for idx in range(len(dataset))]
+        self.items = []
+        for idx in range(len(dataset)):
+            item = self.loader.get_item(idx)
+            if item is not None:  # Only add valid items
+                if self.pretraining:
+                    img = item[0] if isinstance(item, tuple) else item
+                    self.items.append((img, 'denoise', img))
+                    self.items.append((img, 'rotate', 0))
+                else:
+                    self.items.append(item)
 
-        # Split into train and validation sets
+        if not self.items:
+            raise ValueError("No valid items were loaded from the dataset")
+
         train_indices, val_indices = self._split_data()
 
-        # Create datasets for training and validation
         self.train_dataset = self._create_subset(train_indices)
         self.val_dataset = self._create_subset(val_indices)
 
+        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
+        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)
+
     def _split_data(self):
         if len(self.items) == 0:
-            return [], []
+            raise ValueError("No items to split")
 
-        tasks = [item[1] if len(item) > 1 and hasattr(item, '__getitem__') else None for item in self.items]
-        unique_tasks = list(set(tasks)) if tasks.count(None) < len(tasks) else []
+        num_samples = len(self.items)
+        indices = list(range(num_samples))
+        random.shuffle(indices)
 
-        train_indices = []
-        val_indices = []
+        split_idx = int((1 - self.val_split) * num_samples)
+        train_indices = indices[:split_idx]
+        val_indices = indices[split_idx:]
 
-        for task in unique_tasks:
-            task_indices = [i for i, t in enumerate(tasks) if t == task]
-            n_val = int(len(task_indices) * self.val_split)
-
-            random.shuffle(task_indices)
-
-            val_indices.extend(task_indices[:n_val])
-            train_indices.extend(task_indices[n_val:])
-
         return train_indices, val_indices
 
     def _create_subset(self, indices):
         subset_items = [self.items[i] for i in indices]
-        return AIIADataset(subset_items)
+        return AIIADataset(subset_items, pretraining=self.pretraining)
 
 class AIIADataset(torch.utils.data.Dataset):
-    def __init__(self, items):
+    def __init__(self, items, pretraining=False):
         self.items = items
+        self.pretraining = pretraining
+        self.transform = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor()
+        ])
 
     def __len__(self):
         return len(self.items)
 
     def __getitem__(self, idx):
         item = self.items[idx]
-        if isinstance(item, tuple) and len(item) == 2:
-            image, label = item
-            return (image, label)
-        elif isinstance(item, tuple) and len(item) == 3:
+
+        if self.pretraining:
             image, task, label = item
-            # Handle tasks accordingly (e.g., apply different augmentations)
+            if not isinstance(image, Image.Image):
+                raise ValueError(f"Invalid image at index {idx}")
+
+            image = self.transform(image)
+            if image.shape != (3, 224, 224):
+                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+
             if task == 'denoise':
                 noise_std = 0.1
                 noisy_img = image + torch.randn_like(image) * noise_std
-                target = image
-                return (noisy_img, target, task)
+                target = image.clone()
+                return noisy_img, target, task
             elif task == 'rotate':
                 angles = [0, 90, 180, 270]
                 angle = random.choice(angles)
                 rotated_img = transforms.functional.rotate(image, angle)
-                target = torch.tensor(angle).long()
-                return (rotated_img, target, task)
+                target = torch.tensor(angle / 90).long()
+                return rotated_img, target, task
             else:
-                raise ValueError(f"Unknown task: {task}")
+                raise ValueError(f"Invalid task at index {idx}: {task}")
         else:
-            # Handle single images without labels or tasks
-            if isinstance(item, Image.Image):
-                return item
+            if isinstance(item, tuple) and len(item) == 2:
+                image, label = item
+                if not isinstance(image, Image.Image):
+                    raise ValueError(f"Invalid image at index {idx}")
+                image = self.transform(image)
+                if image.shape != (3, 224, 224):
+                    raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+                return image, label
             else:
-                raise ValueError("Invalid item format.")
+                if isinstance(item, Image.Image):
+                    image = self.transform(item)
+                else:
+                    image = self.transform(item[0])
+                if image.shape != (3, 224, 224):
+                    raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+                return image
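Note on the DataLoader changes above: the loaders now target pandas DataFrames (`.iloc`, `.columns`) instead of HuggingFace datasets, RGBA images are flattened onto a black background rather than passed through `convert("RGB")` directly, and in `pretraining` mode every source image yields two self-supervised items (one `denoise`, one `rotate`). A minimal sketch of the resulting behaviour, using a throwaway two-row DataFrame (the `encode` helper and its colours are invented for the example):

```python
import io
import pandas as pd
from PIL import Image
from aiia.data.DataLoader import AIIADataLoader

def encode(color):
    # JPEG-encode a solid-colour image so it resembles an 'image_bytes' column.
    buf = io.BytesIO()
    Image.new("RGB", (256, 256), color).save(buf, format="JPEG")
    return buf.getvalue()

df = pd.DataFrame({"image_bytes": [encode("red"), encode("blue")]})

# Raw bytes in the column trigger the JPGImageLoader branch.
loader = AIIADataLoader(df, column="image_bytes", batch_size=2, pretraining=True)

# Each image contributes a 'denoise' and a 'rotate' item to the pool.
sample, target, task = loader.train_dataset[0]
print(task, sample.shape)  # e.g. 'denoise' with a (3, 224, 224) tensor
```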
diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py
index a454c8c..771caf8 100644
--- a/src/aiia/model/Model.py
+++ b/src/aiia/model/Model.py
@@ -1,8 +1,9 @@
-from config import AIIAConfig
+from .config import AIIAConfig
 from torch import nn
 import torch
 import os
-import copy # Add this for deep copying
+import copy
+
 
 class AIIA(nn.Module):
     def __init__(self, config: AIIAConfig, **kwargs):
@@ -79,8 +80,9 @@ class AIIABaseShared(AIIA):
 
         # Initialize max pooling layer
         self.max_pool = nn.MaxPool2d(
-            kernel_size=self.config.kernel_size,
-            padding=1 # Using same padding as in Conv2d layers
+            kernel_size=1,
+            stride=1,
+            padding=0  # PyTorch requires padding <= kernel_size // 2, so a 1x1 pool must use 0
         )
 
     def forward(self, x):
@@ -116,7 +118,7 @@ class AIIABase(AIIA):
             nn.Conv2d(in_channels, self.config.hidden_size,
                       kernel_size=self.config.kernel_size, padding=1),
             getattr(nn, self.config.activation_function)(),
-            nn.MaxPool2d(kernel_size=2)
+            nn.MaxPool2d(kernel_size=1, stride=1)
         ])
         in_channels = self.config.hidden_size
 
@@ -221,9 +223,4 @@ class AIIArecursive(AIIA):
             processed_patches.append(pp)
 
         combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
-        return combined_output
-
-config = AIIAConfig()
-model2 = AIIABaseShared(config)
-
-model2.save("shared")
\ No newline at end of file
+        return combined_output
\ No newline at end of file
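With the max-pooling layers reduced to 1x1 with stride 1, the backbone no longer downsamples, which is what lets the 1x1 projection convolutions in `pretrain.py` map features straight back to image space. A quick smoke test (a minimal sketch, assuming `AIIABase.forward` returns the final convolutional feature map, as the pretraining script relies on):

```python
import torch
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig

config = AIIAConfig()   # kernel_size now defaults to 3, hidden_size to 512
model = AIIABase(config)

# Conv layers use padding=1 with the 3x3 kernel and the pools are identity,
# so spatial resolution should be preserved end to end.
x = torch.randn(1, 3, 224, 224)
features = model(x)
print(features.shape)   # expected: torch.Size([1, 512, 224, 224])
```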
diff --git a/src/aiia/model/__init__.py b/src/aiia/model/__init__.py
index 5757152..0e6a459 100644
--- a/src/aiia/model/__init__.py
+++ b/src/aiia/model/__init__.py
@@ -1,2 +1,2 @@
 from .config import AIIAConfig
-from .Model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIAresursive
\ No newline at end of file
+from .Model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIArecursive, AIIABaseShared
\ No newline at end of file
diff --git a/src/aiia/model/config.py b/src/aiia/model/config.py
index e2ae83e..02bc709 100644
--- a/src/aiia/model/config.py
+++ b/src/aiia/model/config.py
@@ -8,7 +8,7 @@ class AIIAConfig:
     def __init__(
         self,
         model_name: str = "AIIA",
-        kernel_size: int = 5,
+        kernel_size: int = 3,
         activation_function: str = "GELU",
         hidden_size: int = 512,
         num_hidden_layers: int = 12,
diff --git a/src/pretrain.py b/src/pretrain.py
index daea216..02e4e7f 100644
--- a/src/pretrain.py
+++ b/src/pretrain.py
@@ -1,149 +1,226 @@
 import torch
 from torch import nn
-from torch.utils.data import Dataset, DataLoader
-from torchvision import transforms
-from PIL import Image
-import os
-import random
+import csv
 import pandas as pd
 from aiia.model.config import AIIAConfig
 from aiia.model import AIIABase
 from aiia.data.DataLoader import AIIADataLoader
+from tqdm import tqdm
+
+class ProjectionHead(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
+        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees
+
+    def forward(self, x, task='denoise'):
+        if task == 'denoise':
+            return self.conv_denoise(x)
+        else:
+            return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task
 
 def pretrain_model(data_path1, data_path2, num_epochs=3):
-    # Merge the two parquet files
-    df1 = pd.read_parquet(data_path1)
-    df2 = pd.read_parquet(data_path2)
+    # Read and merge datasets
+    df1 = pd.read_parquet(data_path1).head(10000)
+    df2 = pd.read_parquet(data_path2).head(10000)
     merged_df = pd.concat([df1, df2], ignore_index=True)
 
-    # Create a new AIIAConfig instance
+    # Model configuration
     config = AIIAConfig(
-        model_name="AIIA-512x",
-        hidden_size=512,
-        num_hidden_layers=12,
-        kernel_size=5,
-        learning_rate=5e-5
+        model_name="AIIA-Base-512x20k",
     )
 
-    # Initialize the base model
+    # Initialize model and projection head
     model = AIIABase(config)
-
-    # Create dataset loader with merged data
-    train_dataset = AIIADataLoader(
-        merged_df,
-        batch_size=32,
-        val_split=0.2,
-        seed=42,
-        column="file_path",
-        label_column=None
-    )
-
-    # Create separate dataloaders for training and validation sets
-    train_dataloader = DataLoader(
-        train_dataset.train_dataset,
-        batch_size=train_dataset.batch_size,
-        shuffle=True,
-        num_workers=4
-    )
-
-    val_dataloader = DataLoader(
-        train_dataset.val_ataset,
-        batch_size=train_dataset.batch_size,
-        shuffle=False,
-        num_workers=4
-    )
-
-    # Initialize loss functions and optimizer
-    criterion_denoise = nn.MSELoss()
-    criterion_rotate = nn.CrossEntropyLoss()
-
-    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
-
+    projection_head = ProjectionHead()
+
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model.to(device)
+    projection_head.to(device)
+
+    def safe_collate(batch):
+        denoise_batch = []
+        rotate_batch = []
+
+        for sample in batch:
+            try:
+                noisy_img, target, task = sample
+                if task == 'denoise':
+                    denoise_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
+                else:  # rotate task
+                    rotate_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
+            except Exception as e:
+                print(f"Skipping sample due to error: {e}")
+                continue
+
+        if not denoise_batch and not rotate_batch:
+            return None
+
+        batch_data = {
+            'denoise': None,
+            'rotate': None
+        }
+
+        if denoise_batch:
+            images = torch.stack([x['image'] for x in denoise_batch])
+            targets = torch.stack([x['target'] for x in denoise_batch])
+            batch_data['denoise'] = (images, targets)
+
+        if rotate_batch:
+            images = torch.stack([x['image'] for x in rotate_batch])
+            targets = torch.stack([x['target'] for x in rotate_batch])
+            batch_data['rotate'] = (images, targets)
+
+        return batch_data
+
+    aiia_loader = AIIADataLoader(
+        merged_df,
+        column="image_bytes",
+        batch_size=2,
+        pretraining=True,
+        collate_fn=safe_collate
+    )
+
+    train_loader = aiia_loader.train_loader
+    val_loader = aiia_loader.val_loader
+
+    criterion_denoise = nn.MSELoss()
+    criterion_rotate = nn.CrossEntropyLoss()
+
+    # Update optimizer to include projection head parameters
+    optimizer = torch.optim.AdamW(
+        list(model.parameters()) + list(projection_head.parameters()),
+        lr=config.learning_rate
+    )
 
     best_val_loss = float('inf')
-
+    train_losses = []
+    val_losses = []
     for epoch in range(num_epochs):
         print(f"\nEpoch {epoch+1}/{num_epochs}")
         print("-" * 20)
 
         # Training phase
         model.train()
+        projection_head.train()
         total_train_loss = 0.0
-        denoise_train_loss = 0.0
-        rotate_train_loss = 0.0
-
-        for batch in train_dataloader:
-            images, targets, tasks = zip(*batch)
-
-            if device == "cuda":
-                images = [img.cuda() for img in images]
-                targets = [t.cuda() for t in targets]
-
+        batch_count = 0
+
+        for batch_data in tqdm(train_loader):
+            if batch_data is None:
+                continue
+
             optimizer.zero_grad()
-
-            # Process each sample individually since tasks can vary
-            outputs = []
-            total_loss = 0.0
-            for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
-                output = model(image.unsqueeze(0))
+            batch_loss = 0
+
+            # Handle denoise task
+            if batch_data['denoise'] is not None:
+                noisy_imgs, targets = batch_data['denoise']
+                noisy_imgs = noisy_imgs.to(device)
+                targets = targets.to(device)
 
-                if task == 'denoise':
-                    loss = criterion_denoise(output.squeeze(), target)
-                elif task == 'rotate':
-                    loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
+                # Get features from base model
+                features = model(noisy_imgs)
+                # Project features back to image space
+                outputs = projection_head(features, task='denoise')
+                loss = criterion_denoise(outputs, targets)
+                batch_loss += loss
+
+            # Handle rotate task
+            if batch_data['rotate'] is not None:
+                imgs, targets = batch_data['rotate']
+                imgs = imgs.to(device)
+                targets = targets.long().to(device)
 
-                total_loss += loss
-                outputs.append(output)
-
-            avg_loss = total_loss / len(images)
-            avg_loss.backward()
-            optimizer.step()
-
-            total_train_loss += avg_loss.item()
-            # Separate losses for reporting (you'd need to track this based on tasks)
-
-        avg_total_train_loss = total_train_loss / len(train_dataloader)
-        print(f"Training Loss: {avg_total_train_loss:.4f}")
+                # Get features from base model
+                features = model(imgs)
+                # Project features to rotation predictions
+                outputs = projection_head(features, task='rotate')
+
+                loss = criterion_rotate(outputs, targets)
+                batch_loss += loss
+
+            if batch_loss > 0:
+                batch_loss.backward()
+                optimizer.step()
+                total_train_loss += batch_loss.item()
+                batch_count += 1
+
+        avg_train_loss = total_train_loss / max(batch_count, 1)
+        train_losses.append(avg_train_loss)
+        print(f"Training Loss: {avg_train_loss:.4f}")
 
         # Validation phase
         model.eval()
+        projection_head.eval()
+        val_loss = 0.0
+        val_batch_count = 0
+
         with torch.no_grad():
-            val_losses = []
-            for batch in val_dataloader:
-                images, targets, tasks = zip(*batch)
-
-                if device == "cuda":
-                    images = [img.cuda() for img in images]
-                    targets = [t.cuda() for t in targets]
-
-                outputs = []
-                total_loss = 0.0
-                for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
-                    output = model(image.unsqueeze(0))
+            for batch_data in val_loader:
+                if batch_data is None:
+                    continue
+
+                batch_loss = 0
+
+                if batch_data['denoise'] is not None:
+                    noisy_imgs, targets = batch_data['denoise']
+                    noisy_imgs = noisy_imgs.to(device)
+                    targets = targets.to(device)
 
-                    if task == 'denoise':
-                        loss = criterion_denoise(output.squeeze(), target)
-                    elif task == 'rotate':
-                        loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
+                    features = model(noisy_imgs)
+                    outputs = projection_head(features, task='denoise')
+                    loss = criterion_denoise(outputs, targets)
+                    batch_loss += loss
+
+                if batch_data['rotate'] is not None:
+                    imgs, targets = batch_data['rotate']
+                    imgs = imgs.to(device)
+                    targets = targets.long().to(device)
 
-                    total_loss += loss
-                    outputs.append(output)
-
-                avg_val_loss = total_loss / len(images)
-                val_losses.append(avg_val_loss.item())
-
-        avg_val_loss = sum(val_losses) / len(val_dataloader)
-        print(f"Validation Loss: {avg_val_loss:.4f}")
-
-        # Save the best model
+                    features = model(imgs)
+                    outputs = projection_head(features, task='rotate')
+                    loss = criterion_rotate(outputs, targets)
+                    batch_loss += loss
+
+                if batch_loss > 0:
+                    val_loss += batch_loss.item()
+                    val_batch_count += 1
+
+        avg_val_loss = val_loss / max(val_batch_count, 1)
+        val_losses.append(avg_val_loss)
+        print(f"Validation Loss: {avg_val_loss:.4f}")
+
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
-            model.save("BASEv0.1")
+            # Save the base model (the projection head is only needed during pretraining)
+            model.save("AIIA-base-512")
             print("Best model saved!")
 
+    # Prepare the data to be written to the CSV file
+    data = list(zip(range(1, len(train_losses) + 1), train_losses, val_losses))
+
+    # Specify the CSV file name
+    csv_file = 'losses.csv'
+
+    # Write the data to the CSV file
+    with open(csv_file, mode='w', newline='') as file:
+        writer = csv.writer(file)
+        # Write the header
+        writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
+        # Write the data
+        writer.writerows(data)
+    print(f"Data has been written to {csv_file}")
+
 if __name__ == "__main__":
-    data_path1 = "/root/training_data/vision-dataset/images_dataset.parquet"
+    data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
     data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
-    pretrain_model(data_path1, data_path2, num_epochs=8)
\ No newline at end of file
+    pretrain_model(data_path1, data_path2, num_epochs=10)
\ No newline at end of file
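End to end, the new entry point can be driven as in the sketch below (the paths are illustrative, not part of the patch, and the parquet files must provide the `image_bytes` column the loader is pointed at). The per-epoch losses land in `losses.csv` in the working directory:

```python
# Illustrative driver for the new pretraining entry point,
# assuming it is run from src/ so that 'pretrain' is importable.
import pandas as pd
from pretrain import pretrain_model

pretrain_model(
    "data/images_checkpoint.parquet",   # placeholder paths
    "data/vec_images_dataset.parquet",
    num_epochs=10,
)

# Inspect the loss curve written by the script.
history = pd.read_csv("losses.csv")
print(history.tail())
```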