diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..6925c8f
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,4 @@
+include LICENSE
+include README.md
+include requirements.txt
+recursive-include src/aiia *
\ No newline at end of file
diff --git a/README.md b/README.md
index 0d888a0..6f149b1 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,30 @@
 # AIIA
+
+## Example Usage
+```python
+from aiia.model import AIIABase
+from aiia.model.config import AIIAConfig
+from aiia.pretrain import Pretrainer
+
+# Create your model
+config = AIIAConfig(model_name="AIIA-Base-512x20k")
+model = AIIABase(config)
+
+# Initialize the pretrainer with the model and its configuration
+pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)
+
+# List of dataset paths
+dataset_paths = [
+    "/path/to/dataset1.parquet",
+    "/path/to/dataset2.parquet"
+]
+
+# Start training with multiple datasets
+pretrainer.train(
+    dataset_paths=dataset_paths,
+    num_epochs=10,
+    batch_size=2,
+    sample_size=10000
+)
+```
\ No newline at end of file
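During training, `Pretrainer` saves its best checkpoint via `model.save("AIIA-base-512")`, which writes a directory containing `config.json` and `model.pth`. A minimal sketch of loading that checkpoint back for inference (assuming the directory name used by the pretrainer):

```python
from aiia.model import AIIABase

# AIIA.load() reads config.json and model.pth from the given directory.
model = AIIABase.load("AIIA-base-512")
model.eval()
```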
diff --git a/example.py b/example.py
new file mode 100644
index 0000000..6d605ca
--- /dev/null
+++ b/example.py
@@ -0,0 +1,27 @@
+from aiia.model import AIIABase
+from aiia.model.config import AIIAConfig
+from aiia.pretrain import Pretrainer
+
+data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
+data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
+
+# Create your model
+config = AIIAConfig(model_name="AIIA-Base-512x10k-small", num_hidden_layers=6, hidden_size=256)
+model = AIIABase(config)
+
+# Initialize pretrainer with the model
+pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)
+
+# List of dataset paths
+dataset_paths = [
+    data_path1,
+    data_path2
+]
+
+# Start training with multiple datasets
+pretrainer.train(
+    dataset_paths=dataset_paths,
+    num_epochs=10,
+    batch_size=2,
+    sample_size=10000
+)
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..a8bdbe9
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,8 @@
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.black]
+line-length = 88
+target-version = ['py37']
+include = '\.pyi?$'
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..06e8438
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+torch>=1.8.0
+torchvision
+numpy
+pandas
+tqdm
+pytest
+pillow
\ No newline at end of file
diff --git a/run.py b/run.py
new file mode 100644
index 0000000..fb20e63
--- /dev/null
+++ b/run.py
@@ -0,0 +1,27 @@
+from aiia.model import AIIABase
+from aiia.model.config import AIIAConfig
+from aiia.pretrain import Pretrainer
+
+data_path1 = "/root/training_data/vision-dataset/images_pretrain.parquet"
+data_path2 = "/root/training_data/vision-dataset/vector_img_pretrain.parquet"
+
+# Create your model
+config = AIIAConfig(model_name="AIIA-Base-512x20k")
+model = AIIABase(config)
+
+# Initialize pretrainer with the model
+pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)
+
+# List of dataset paths
+dataset_paths = [
+    data_path1,
+    data_path2
+]
+
+# Start training with multiple datasets
+pretrainer.train(
+    dataset_paths=dataset_paths,
+    num_epochs=10,
+    batch_size=2,
+    sample_size=10000
+)
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..fb45363
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,29 @@
+[metadata]
+name = aiia
+version = 0.1.0
+author = Falko Habel
+author_email = falko.habel@gmx.de
+description = AIIA deep learning model implementation
+long_description = file: README.md
+long_description_content_type = text/markdown
+url = https://gitea.fabelous.app/Maschine-Learning/AIIA.git
+classifiers =
+    Programming Language :: Python :: 3
+    License :: Free for non-commercial use
+    Operating System :: OS Independent
+
+[options]
+package_dir =
+    = src
+packages = find:
+python_requires = >=3.7
+install_requires =
+    torch>=1.8.0
+    torchvision
+    numpy>=1.19.0
+    pandas
+    pillow
+    tqdm>=4.62.0
+
+[options.packages.find]
+where = src
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..0eb6be6
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,28 @@
+from setuptools import setup, find_packages
+
+setup(
+    name="aiia",
+    version="0.1.0",
+    packages=find_packages(where="src"),
+    package_dir={"": "src"},
+    install_requires=[
+        "torch>=1.8.0",
+        "torchvision",
+        "numpy>=1.19.0",
+        "pandas",
+        "pillow",
+        "tqdm>=4.62.0",
+    ],
+    author="Falko Habel",
+    author_email="falko.habel@gmx.de",
+    description="AIIA deep learning model implementation",
+    long_description=open("README.md").read(),
+    long_description_content_type="text/markdown",
+    url="https://gitea.fabelous.app/Maschine-Learning/AIIA.git",
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: Free for non-commercial use",
+        "Operating System :: OS Independent",
+    ],
+    python_requires=">=3.7",
+)
diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py
new file mode 100644
index 0000000..6a27146
--- /dev/null
+++ b/src/aiia/__init__.py
@@ -0,0 +1,7 @@
+from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAExpert, AIIAmoe, AIIA, AIIArecursive
+from .model.config import AIIAConfig
+from .data.DataLoader import AIIADataLoader
+from .pretrain.pretrainer import Pretrainer, ProjectionHead
+
+
+__version__ = "0.1.0"
diff --git a/src/aiia/data/DataLoader.py b/src/aiia/data/DataLoader.py
new file mode 100644
index 0000000..4ba5032
--- /dev/null
+++ b/src/aiia/data/DataLoader.py
@@ -0,0 +1,228 @@
+import io
+from PIL import Image
+import torch
+from torch.utils.data import DataLoader
+from torchvision import transforms
+import random
+import re
+import base64
+
+class FilePathLoader:
+    def __init__(self, dataset, file_path_column="file_path", label_column=None):
+        self.dataset = dataset
+        self.file_path_column = file_path_column
+        self.label_column = label_column
+        self.successful_count = 0
+        self.skipped_count = 0
+
+        if self.file_path_column not in dataset.columns:
+            raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")
+
+    def _get_image(self, item):
+        path = item[self.file_path_column]
+        try:
+            image = Image.open(path)
+            if image.mode == 'RGBA':
+                # Flatten the alpha channel onto a black background
+                background = Image.new('RGB', image.size, (0, 0, 0))
+                background.paste(image, mask=image.split()[3])
+                image = background
+            elif image.mode != 'RGB':
+                image = image.convert('RGB')
+            return image
+        except Exception as e:
+            print(f"Error loading image from {path}: {e}")
+            return None
+
+    def get_item(self, idx):
+        item = self.dataset.iloc[idx]
+        image = self._get_image(item)
+        if image is not None:
+            self.successful_count += 1
+            if self.label_column is not None:
+                label = item.get(self.label_column)
+                return (image, label)
+            else:
+                return (image,)
+        else:
+            self.skipped_count += 1
+            return None
+
+    def print_summary(self):
+        print(f"Successfully converted {self.successful_count} images.")
+        print(f"Skipped {self.skipped_count} images due to errors.")
+
+class JPGImageLoader:
+    def __init__(self, dataset, bytes_column="jpg", label_column=None):
+        self.dataset = dataset
+        self.bytes_column = bytes_column
+        self.label_column = label_column
+        self.successful_count = 0
+        self.skipped_count = 0
+
+        if self.bytes_column not in dataset.columns:
+            raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")
+
+    def _get_image(self, item):
+        try:
+            data = item[self.bytes_column]
+
+            if isinstance(data, str) and data.startswith("b'"):
+                # Undo a stringified Python bytes literal (e.g. "b'\xff...'")
+                cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
+                bytes_data = cleaned_data
+            elif isinstance(data, str):
+                bytes_data = base64.b64decode(data)
+            else:
+                bytes_data = data
+
+            img_bytes = io.BytesIO(bytes_data)
+            image = Image.open(img_bytes)
+            if image.mode == 'RGBA':
+                background = Image.new('RGB', image.size, (0, 0, 0))
+                background.paste(image, mask=image.split()[3])
+                image = background
+            elif image.mode != 'RGB':
+                image = image.convert('RGB')
+            return image
+        except Exception as e:
+            print(f"Error loading image from bytes: {e}")
+            return None
+
+    def get_item(self, idx):
+        item = self.dataset.iloc[idx]
+        image = self._get_image(item)
+        if image is not None:
+            self.successful_count += 1
+            if self.label_column is not None:
+                label = item.get(self.label_column)
+                return (image, label)
+            else:
+                return (image,)
+        else:
+            self.skipped_count += 1
+            return None
+
+    def print_summary(self):
+        print(f"Successfully converted {self.successful_count} images.")
+        print(f"Skipped {self.skipped_count} images due to errors.")
+
+class AIIADataLoader:
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
+        self.batch_size = batch_size
+        self.val_split = val_split
+        self.seed = seed
+        self.pretraining = pretraining
+        random.seed(seed)
+
+        # Choose a backing loader from the first value in the column:
+        # raw bytes, stringified bytes, or data URIs -> JPGImageLoader;
+        # file-system paths -> FilePathLoader.
+        sample_value = dataset[column].iloc[0]
+        is_bytes_or_bytestring = isinstance(sample_value, bytes) or (
+            isinstance(sample_value, str)
+            and sample_value.startswith(("b'", 'b"', 'data:image'))
+        )
+
+        if is_bytes_or_bytestring:
+            self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
+        else:
+            sample_paths = dataset[column].dropna().head(1).astype(str)
+            filepath_pattern = r'.*(?:/|\\).*\.(jpg|png|gif)$'
+
+            if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
+                self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
+            else:
+                self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
+
+        self.items = []
+        for idx in range(len(dataset)):
+            item = self.loader.get_item(idx)
+            if item is not None:  # Only add valid items
+                if self.pretraining:
+                    img = item[0] if isinstance(item, tuple) else item
+                    self.items.append((img, 'denoise', img))
+                    self.items.append((img, 'rotate', 0))
+                else:
+                    self.items.append(item)
+
+        if not self.items:
+            raise ValueError("No valid items were loaded from the dataset")
+
+        train_indices, val_indices = self._split_data()
+
+        self.train_dataset = self._create_subset(train_indices)
+        self.val_dataset = self._create_subset(val_indices)
+
+        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
+        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)
+
+    def _split_data(self):
+        if len(self.items) == 0:
+            raise ValueError("No items to split")
+
+        num_samples = len(self.items)
+        indices = list(range(num_samples))
+        random.shuffle(indices)
+
+        split_idx = int((1 - self.val_split) * num_samples)
+        train_indices = indices[:split_idx]
+        val_indices = indices[split_idx:]
+
+        return train_indices, val_indices
+
+    def _create_subset(self, indices):
+        subset_items = [self.items[i] for i in indices]
+        return AIIADataset(subset_items, pretraining=self.pretraining)
+
+class AIIADataset(torch.utils.data.Dataset):
+    def __init__(self, items, pretraining=False):
+        self.items = items
+        self.pretraining = pretraining
+        self.transform = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor()
+        ])
+
+    def __len__(self):
+        return len(self.items)
+
+    def __getitem__(self, idx):
+        item = self.items[idx]
+
+        if self.pretraining:
+            image, task, label = item
+            if not isinstance(image, Image.Image):
+                raise ValueError(f"Invalid image at index {idx}")
+
+            image = self.transform(image)
+            if image.shape != (3, 224, 224):
+                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+
+            if task == 'denoise':
+                noise_std = 0.1
+                noisy_img = image + torch.randn_like(image) * noise_std
+                noisy_img = torch.clamp(noisy_img, 0.0, 1.0)  # keep pixel values in [0, 1]
+                target = image.clone()
+                return noisy_img, target, task
+            elif task == 'rotate':
+                angles = [0, 90, 180, 270]
+                angle = random.choice(angles)
+                rotated_img = transforms.functional.rotate(image, angle)
+                target = torch.tensor(angle // 90).long()
+                return rotated_img, target, task
+            else:
+                raise ValueError(f"Invalid task at index {idx}: {task}")
+        else:
+            if isinstance(item, tuple) and len(item) == 2:
+                image, label = item
+                if not isinstance(image, Image.Image):
+                    raise ValueError(f"Invalid image at index {idx}")
+                image = self.transform(image)
+                if image.shape != (3, 224, 224):
+                    raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+                return image, label
+            else:
+                image = item if isinstance(item, Image.Image) else item[0]
+                image = self.transform(image)
+                if image.shape != (3, 224, 224):
+                    raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+                return image
diff --git a/src/aiia/data/__init__.py b/src/aiia/data/__init__.py
new file mode 100644
index 0000000..5e8a93c
--- /dev/null
+++ b/src/aiia/data/__init__.py
@@ -0,0 +1,3 @@
+from .DataLoader import AIIADataLoader
+
+__all__ = ["AIIADataLoader"]
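`AIIADataLoader` picks a backing loader by inspecting the first value of `column`, and in pretraining mode it emits two task tuples (denoise and rotate) per image, so batches mix task types. A minimal sketch of driving it directly with a pandas DataFrame (the parquet path and column name are hypothetical; `Pretrainer.safe_collate` is the collate function used during pretraining):

```python
import pandas as pd
from aiia.data import AIIADataLoader
from aiia.pretrain import Pretrainer

df = pd.read_parquet("/path/to/images.parquet")  # hypothetical dataset

loader = AIIADataLoader(
    df,
    batch_size=8,
    val_split=0.2,
    column="file_path",          # or a bytes column such as "jpg"
    pretraining=True,
    collate_fn=Pretrainer.safe_collate,
)

# Each batch is a dict: {'denoise': (imgs, targets) or None, 'rotate': (imgs, labels) or None}
batch = next(iter(loader.train_loader))
```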
+ """ + super().__init__(config=config, **kwargs) + + # Update configuration with new parameters if provided + self. config = copy.deepcopy(config) + + for key, value in kwargs.items(): + setattr(self.config, key, value) + + # Initialize the network components + self._initialize_network() + self._initialize_activation_andPooling() + + def _initialize_network(self): + """Initialize the shared and unique layers of the network.""" + # Create a single shared convolutional layer + self.shared_layer = nn.Conv2d( + in_channels=self.config.num_channels, + out_channels=self.config.hidden_size, + kernel_size=self.config.kernel_size, + padding=1 # Using same padding as defined in config + ) + + # Initialize the unique layers with separate weights and biases + self.unique_layers = nn.ModuleList() + current_in_channels = self.config.hidden_size + + layer = nn.Conv2d( + in_channels=current_in_channels, + out_channels=self.config.hidden_size, + kernel_size=self.config.kernel_size, + padding=1 # Using same padding as defined in config + ) + + self.unique_layers.append(layer) + + def _initialize_activation_andPooling(self): + """Initialize activation function and pooling layers.""" + # Get activation function from nn module + self.activation = getattr(nn, self.config.activation_function)() + + # Initialize max pooling layer + self.max_pool = nn.MaxPool2d( + kernel_size=1, + stride=1, + padding=1 + ) + + def forward(self, x): + """Forward pass of the network.""" + # Apply shared layer transformation + out = self.shared_layer(x) + + # Pass through activation function + out = self.activation(out) + + # Apply max pooling + out = self.max_pool(out) + + # Pass through unique layers + for unique_layer in self.unique_layers: + out = unique_layer(out) + out = self.activation(out) + out = self.max_pool(out) + + return out + +class AIIABase(AIIA): + def __init__(self, config: AIIAConfig, **kwargs): + super().__init__(config=config, **kwargs) + self.config = self.config + + # Initialize layers based on configuration + layers = [] + in_channels = self.config.num_channels + + for _ in range(self.config.num_hidden_layers): + layers.extend([ + nn.Conv2d(in_channels, self.config.hidden_size, + kernel_size=self.config.kernel_size, padding=1), + getattr(nn, self.config.activation_function)(), + nn.MaxPool2d(kernel_size=1, stride=1) + ]) + in_channels = self.config.hidden_size + + self.cnn = nn.Sequential(*layers) + + def forward(self, x): + return self.cnn(x) + +class AIIAExpert(AIIA): + def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs): + super().__init__(config=config, **kwargs) + self.config = self.config + + # Initialize base CNN with configuration and chosen base class + if issubclass(base_class, AIIABase): + self.base_cnn = AIIABase(self.config, **kwargs) + elif issubclass(base_class, AIIABaseShared): + self.base_cnn = AIIABaseShared(self.config, **kwargs) + else: + raise ValueError("Invalid base class") + +class AIIAmoe(AIIA): + def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs): + super().__init__(config=config, **kwargs) + self.config = self.config + + # Update config with new parameters if provided + self.config.num_experts = num_experts + + # Initialize multiple experts using chosen base class + self.experts = nn.ModuleList([ + AIIAExpert(self.config, base_class=base_class, **kwargs) + for _ in range(self.config.num_experts) + ]) + + # Create gating network + self.gate = nn.Sequential( + nn.Linear(self.config.hidden_size, self.config.num_experts), + 
+
+class AIIAExpert(AIIA):
+    def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
+        super().__init__(config=config, **kwargs)
+
+        # Initialize base CNN with configuration and chosen base class
+        if issubclass(base_class, AIIABase):
+            self.base_cnn = AIIABase(self.config, **kwargs)
+        elif issubclass(base_class, AIIABaseShared):
+            self.base_cnn = AIIABaseShared(self.config, **kwargs)
+        else:
+            raise ValueError("Invalid base class")
+
+    def forward(self, x):
+        return self.base_cnn(x)
+
+class AIIAmoe(AIIA):
+    def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
+        super().__init__(config=config, **kwargs)
+
+        # Update config with new parameters if provided
+        self.config.num_experts = num_experts
+
+        # Initialize multiple experts using chosen base class
+        self.experts = nn.ModuleList([
+            AIIAExpert(self.config, base_class=base_class, **kwargs)
+            for _ in range(self.config.num_experts)
+        ])
+
+        # Create gating network
+        self.gate = nn.Sequential(
+            nn.Linear(self.config.hidden_size, self.config.num_experts),
+            nn.Softmax(dim=1)
+        )
+
+    def forward(self, x):
+        # Stack expert feature maps: (batch, num_experts, hidden_size, H, W)
+        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
+        # Pool over experts and spatial dims to feed (batch, hidden_size) to the gate
+        gate_weights = self.gate(expert_outputs.mean(dim=(1, 3, 4)))
+        merged_output = torch.sum(
+            expert_outputs * gate_weights.view(gate_weights.size(0), -1, 1, 1, 1), dim=1
+        )
+        return merged_output
+
+class AIIAchunked(AIIA):
+    def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
+        super().__init__(config=config, **kwargs)
+
+        # Update config with new parameters if provided
+        self.config.patch_size = patch_size
+
+        # Initialize base CNN for processing each patch using the specified base class
+        if issubclass(base_class, AIIABase):
+            self.base_cnn = AIIABase(self.config, **kwargs)
+        elif issubclass(base_class, AIIABaseShared):
+            self.base_cnn = AIIABaseShared(self.config, **kwargs)
+        else:
+            raise ValueError("Invalid base class")
+
+    def forward(self, x):
+        patch_size = self.config.patch_size
+        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
+        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, patch_size, patch_size)
+        patch_outputs = []
+
+        for p in torch.split(patches, 1, dim=2):
+            p = p.squeeze(2)
+            po = self.base_cnn(p)
+            patch_outputs.append(po)
+
+        combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
+        return combined_output
+
+class AIIArecursive(AIIA):
+    def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
+        super().__init__(config=config, **kwargs)
+
+        # Pass recursion_depth as a kwarg to the config
+        self.config.recursion_depth = recursion_depth
+
+        # Initialize chunked CNN with updated config
+        self.chunked_cnn = AIIAchunked(self.config, base_class=base_class, **kwargs)
+
+    def forward(self, x, depth=0):
+        if depth == self.config.recursion_depth:
+            return self.chunked_cnn(x)
+        else:
+            patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
+            patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
+            processed_patches = []
+
+            for p in torch.split(patches, 1, dim=2):
+                p = p.squeeze(2)
+                pp = self.forward(p, depth + 1)
+                processed_patches.append(pp)
+
+            combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
+            return combined_output
+
+
+if __name__ == "__main__":
+    # Quick smoke test; guarded so importing the module has no side effects.
+    config = AIIAConfig()
+    model = AIIAmoe(config, num_experts=5)
+    model.save("test")
\ No newline at end of file
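A short sketch of composing the variants above, e.g. a mixture-of-experts over shared-weight base CNNs; the convolutional stacks preserve spatial size, so shapes follow directly (the hyperparameter values are illustrative):

```python
import torch
from aiia.model import AIIABaseShared, AIIAmoe
from aiia.model.config import AIIAConfig

config = AIIAConfig(hidden_size=128, num_hidden_layers=4)
model = AIIAmoe(config, num_experts=4, base_class=AIIABaseShared)

x = torch.randn(1, 3, 64, 64)
features = model(x)  # gate-weighted sum of expert feature maps: (1, 128, 64, 64)
```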
diff --git a/src/aiia/model/__init__.py b/src/aiia/model/__init__.py
new file mode 100644
index 0000000..f68a42a
--- /dev/null
+++ b/src/aiia/model/__init__.py
@@ -0,0 +1,21 @@
+from .Model import (
+    AIIA,
+    AIIABase,
+    AIIABaseShared,
+    AIIAchunked,
+    AIIAExpert,
+    AIIAmoe,
+    AIIArecursive
+)
+from .config import AIIAConfig
+
+__all__ = [
+    "AIIA",
+    "AIIABase",
+    "AIIABaseShared",
+    "AIIAchunked",
+    "AIIAExpert",
+    "AIIAmoe",
+    "AIIArecursive",
+    "AIIAConfig"
+]
\ No newline at end of file
diff --git a/src/aiia/model/config.py b/src/aiia/model/config.py
new file mode 100644
index 0000000..02bc709
--- /dev/null
+++ b/src/aiia/model/config.py
@@ -0,0 +1,53 @@
+import torch.nn as nn
+import json
+import os
+
+
+class AIIAConfig:
+    def __init__(
+        self,
+        model_name: str = "AIIA",
+        kernel_size: int = 3,
+        activation_function: str = "GELU",
+        hidden_size: int = 512,
+        num_hidden_layers: int = 12,
+        num_channels: int = 3,
+        learning_rate: float = 5e-5,
+        **kwargs
+    ):
+        self.model_name = model_name
+        self.kernel_size = kernel_size
+        self.activation_function = activation_function
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_channels = num_channels
+        self.learning_rate = learning_rate
+
+        # Store additional keyword arguments as attributes
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+    @property
+    def activation_function(self):
+        return self._activation_function
+
+    @activation_function.setter
+    def activation_function(self, value):
+        # Classes are callable, so this single check covers both functions
+        # and nn.Module subclasses such as nn.GELU.
+        attr = getattr(nn, value, None)
+        if attr is None or not callable(attr):
+            valid_funcs = [func for func in dir(nn) if callable(getattr(nn, func))]
+            raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}")
+        self._activation_function = value
+
+    def save(self, file_path):
+        os.makedirs(file_path, exist_ok=True)
+        config_dict = vars(self).copy()
+        # Serialize the property-backed attribute under its public name so
+        # that load() restores it through the validating setter.
+        config_dict["activation_function"] = config_dict.pop("_activation_function")
+        with open(f"{file_path}/config.json", 'w') as f:
+            json.dump(config_dict, f, indent=4)
+
+    @classmethod
+    def load(cls, file_path):
+        with open(f"{file_path}/config.json", 'r') as f:
+            config_dict = json.load(f)
+        return cls(**config_dict)
\ No newline at end of file
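The config round-trips through JSON; because `activation_function` is property-backed, `save()` writes it under its public name so `load()` re-validates it through the setter. A minimal sketch of the save/load cycle (directory name illustrative):

```python
from aiia.model.config import AIIAConfig

config = AIIAConfig(activation_function="ReLU", hidden_size=128)
config.save("checkpoint_dir")                 # writes checkpoint_dir/config.json
restored = AIIAConfig.load("checkpoint_dir")
assert restored.activation_function == "ReLU"
assert restored.hidden_size == 128
```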
diff --git a/src/aiia/pretrain/__init__.py b/src/aiia/pretrain/__init__.py
new file mode 100644
index 0000000..c45cbc4
--- /dev/null
+++ b/src/aiia/pretrain/__init__.py
@@ -0,0 +1,3 @@
+from .pretrainer import Pretrainer, ProjectionHead
+
+__all__ = ["Pretrainer", "ProjectionHead"]
\ No newline at end of file
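`Pretrainer.train()` (below) reads each parquet file with pandas, keeps the first `sample_size` rows, and defaults to an image column named `image_bytes`; datasets that store images under another column can pass it explicitly. A sketch with a hypothetical path and column name:

```python
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer

config = AIIAConfig()
pretrainer = Pretrainer(AIIABase(config), config=config)

pretrainer.train(
    dataset_paths=["/path/to/dataset.parquet"],  # hypothetical
    column="jpg",        # name of the image column in the parquet file
    num_epochs=3,
    batch_size=2,
    sample_size=10000,
)
```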
diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py
new file mode 100644
index 0000000..30ebc92
--- /dev/null
+++ b/src/aiia/pretrain/pretrainer.py
@@ -0,0 +1,230 @@
+import torch
+from torch import nn
+import csv
+import pandas as pd
+from tqdm import tqdm
+from ..model.Model import AIIA
+from ..model.config import AIIAConfig
+from ..data.DataLoader import AIIADataLoader
+
+
+class ProjectionHead(nn.Module):
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
+        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees
+
+    def forward(self, x, task='denoise'):
+        if task == 'denoise':
+            return self.conv_denoise(x)
+        else:
+            return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for the rotation task
+
+class Pretrainer:
+    def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
+        """
+        Initialize the pretrainer with a model.
+
+        Args:
+            model (AIIA): The model instance to pretrain
+            learning_rate (float): Learning rate for optimization
+            config (AIIAConfig): Model configuration; defaults to model.config when omitted
+        """
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = model.to(self.device)
+        if config is None:
+            config = model.config  # every AIIA model carries a copy of its configuration
+        self.projection_head = ProjectionHead(config.hidden_size).to(self.device)
+        self.optimizer = torch.optim.AdamW(
+            list(self.model.parameters()) + list(self.projection_head.parameters()),
+            lr=learning_rate
+        )
+        self.train_losses = []
+        self.val_losses = []
+
+    @staticmethod
+    def safe_collate(batch):
+        """Safely collate batch data handling both denoise and rotate tasks."""
+        denoise_batch = []
+        rotate_batch = []
+
+        for sample in batch:
+            try:
+                noisy_img, target, task = sample
+                if task == 'denoise':
+                    denoise_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
+                else:  # rotate task
+                    rotate_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
+            except Exception as e:
+                print(f"Skipping sample due to error: {e}")
+                continue
+
+        if not denoise_batch and not rotate_batch:
+            return None
+
+        batch_data = {
+            'denoise': None,
+            'rotate': None
+        }
+
+        if denoise_batch:
+            images = torch.stack([x['image'] for x in denoise_batch])
+            targets = torch.stack([x['target'] for x in denoise_batch])
+            batch_data['denoise'] = (images, targets)
+
+        if rotate_batch:
+            images = torch.stack([x['image'] for x in rotate_batch])
+            targets = torch.stack([x['target'] for x in rotate_batch])
+            batch_data['rotate'] = (images, targets)
+
+        return batch_data
+
+    def _process_batch(self, batch_data, criterion_denoise, criterion_rotate, training=True):
+        """Process a single batch of data."""
+        batch_loss = 0
+
+        if batch_data['denoise'] is not None:
+            noisy_imgs, targets = batch_data['denoise']
+            noisy_imgs = noisy_imgs.to(self.device)
+            targets = targets.to(self.device)
+
+            features = self.model(noisy_imgs)
+            outputs = self.projection_head(features, task='denoise')
+            loss = criterion_denoise(outputs, targets)
+            batch_loss += loss
+
+        if batch_data['rotate'] is not None:
+            imgs, targets = batch_data['rotate']
+            imgs = imgs.to(self.device)
+            targets = targets.long().to(self.device)
+
+            features = self.model(imgs)
+            outputs = self.projection_head(features, task='rotate')
+            loss = criterion_rotate(outputs, targets)
+            batch_loss += loss
+
+        return batch_loss
+
+    def train(self, dataset_paths, column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
+        """
+        Train the model using multiple specified datasets.
+
+        Args:
+            dataset_paths (list): List of paths to parquet datasets
+            column (str): Name of the image column in the parquet files
+            num_epochs (int): Number of training epochs
+            batch_size (int): Batch size for training
+            sample_size (int): Number of samples to use from each dataset
+        """
+        if not dataset_paths:
+            raise ValueError("No dataset paths provided")
+
+        # Read and merge all datasets
+        dataframes = []
+        for path in dataset_paths:
+            try:
+                df = pd.read_parquet(path).head(sample_size)
+                dataframes.append(df)
+            except Exception as e:
+                print(f"Error loading dataset {path}: {e}")
+
+        if not dataframes:
+            raise ValueError("No valid datasets could be loaded")
+
+        merged_df = pd.concat(dataframes, ignore_index=True)
+
+        # Initialize data loader
+        aiia_loader = AIIADataLoader(
+            merged_df,
+            column=column,
+            batch_size=batch_size,
+            pretraining=True,
+            collate_fn=self.safe_collate
+        )
+
+        criterion_denoise = nn.MSELoss()
+        criterion_rotate = nn.CrossEntropyLoss()
+        best_val_loss = float('inf')
+
+        for epoch in range(num_epochs):
+            print(f"\nEpoch {epoch+1}/{num_epochs}")
+            print("-" * 20)
+
+            # Training phase
+            self.model.train()
+            self.projection_head.train()
+            total_train_loss = 0.0
+            batch_count = 0
+
+            for batch_data in tqdm(aiia_loader.train_loader):
+                if batch_data is None:
+                    continue
+
+                self.optimizer.zero_grad()
+                batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
+
+                if batch_loss > 0:
+                    batch_loss.backward()
+                    self.optimizer.step()
+                    total_train_loss += batch_loss.item()
+                    batch_count += 1
+
+            avg_train_loss = total_train_loss / max(batch_count, 1)
+            self.train_losses.append(avg_train_loss)
+            print(f"Training Loss: {avg_train_loss:.4f}")
+
+            # Validation phase
+            self.model.eval()
+            self.projection_head.eval()
+            val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
+
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                self.model.save("AIIA-base-512")
+                print("Best model saved!")
+
+        self.save_losses('losses.csv')
+
+    def _validate(self, val_loader, criterion_denoise, criterion_rotate):
+        """Perform validation and return average validation loss."""
+        val_loss = 0.0
+        val_batch_count = 0
+
+        with torch.no_grad():
+            for batch_data in val_loader:
+                if batch_data is None:
+                    continue
+
+                batch_loss = self._process_batch(
+                    batch_data, criterion_denoise, criterion_rotate, training=False
+                )
+
+                if batch_loss > 0:
+                    val_loss += batch_loss.item()
+                    val_batch_count += 1
+
+        avg_val_loss = val_loss / max(val_batch_count, 1)
+        self.val_losses.append(avg_val_loss)
+        print(f"Validation Loss: {avg_val_loss:.4f}")
+        return avg_val_loss
+
+    def save_losses(self, csv_file):
+        """Save training and validation losses to a CSV file."""
+        data = list(zip(
+            range(1, len(self.train_losses) + 1),
+            self.train_losses,
+            self.val_losses
+        ))
+
+        with open(csv_file, mode='w', newline='') as file:
+            writer = csv.writer(file)
+            writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
+            writer.writerows(data)
+        print(f"Loss data has been written to {csv_file}")
\ No newline at end of file
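`save_losses()` writes one row per epoch to `losses.csv`. A small sketch for inspecting or plotting the curves afterwards (assuming matplotlib is available; it is not listed in requirements.txt):

```python
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("losses.csv")  # written by Pretrainer.save_losses()
plt.plot(df["Epoch"], df["Train Loss"], label="train")
plt.plot(df["Epoch"], df["Validation Loss"], label="validation")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()
```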