Merge pull request 'develop' (#5) from develop into main

Reviewed-on: Fabel/AIIA#5

This commit is contained in: commit a2909002a7
@ -0,0 +1,4 @@
include LICENSE
include README.md
include requirements.txt
recursive-include src/aiia *
README.md
@ -1,2 +1,30 @@
# AIIA

## Example Usage:

```python
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer

# Create your model
config = AIIAConfig(model_name="AIIA-Base-512x20k")
model = AIIABase(config)

# Initialize pretrainer with the model and its config
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)

# List of dataset paths
dataset_paths = [
    "/path/to/dataset1.parquet",
    "/path/to/dataset2.parquet"
]

# Start training with multiple datasets
pretrainer.train(
    dataset_paths=dataset_paths,
    num_epochs=10,
    batch_size=2,
    sample_size=10000
)
```
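Training prints per-epoch training and validation losses, saves the best checkpoint whenever the validation loss improves, and writes the loss history to `losses.csv` when it finishes.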
@ -0,0 +1,27 @@
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer

data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"

# Create your model
config = AIIAConfig(model_name="AIIA-Base-512x10k-small", num_hidden_layers=6, hidden_size=256)
model = AIIABase(config)

# Initialize pretrainer with the model
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)

# List of dataset paths
dataset_paths = [
    data_path1,
    data_path2
]

# Start training with multiple datasets
pretrainer.train(
    dataset_paths=dataset_paths,
    num_epochs=10,
    batch_size=2,
    sample_size=10000
)
@ -0,0 +1,8 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[tool.black]
line-length = 88
target-version = ['py37']
include = '\.pyi?$'
@ -0,0 +1,5 @@
torch>=1.8.0
numpy
tqdm
pytest
pillow
@ -0,0 +1,27 @@
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer

data_path1 = "/root/training_data/vision-dataset/images_pretrain.parquet"
data_path2 = "/root/training_data/vision-dataset/vector_img_pretrain.parquet"

# Create your model
config = AIIAConfig(model_name="AIIA-Base-512x20k")
model = AIIABase(config)

# Initialize pretrainer with the model
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)

# List of dataset paths
dataset_paths = [
    data_path1,
    data_path2
]

# Start training with multiple datasets
pretrainer.train(
    dataset_paths=dataset_paths,
    num_epochs=10,
    batch_size=2,
    sample_size=10000
)
@ -0,0 +1,26 @@
[metadata]
name = aiia
version = 0.1.0
author = Falko Habel
author_email = falko.habel@gmx.de
description = AIIA deep learning model implementation
long_description = file: README.md
long_description_content_type = text/markdown
url = https://gitea.fabelous.app/Maschine-Learning/AIIA.git
classifiers =
    Programming Language :: Python :: 3
    License :: OSI Approved :: MIT License
    Operating System :: OS Independent

[options]
package_dir =
    = src
packages = find:
python_requires = >=3.7
install_requires =
    torch>=1.8.0
    numpy>=1.19.0
    tqdm>=4.62.0

[options.packages.find]
where = src
@ -0,0 +1,25 @@
from setuptools import setup, find_packages

setup(
    name="aiia",
    version="0.1.0",
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    install_requires=[
        "torch>=1.8.0",
        "numpy>=1.19.0",
        "tqdm>=4.62.0",
    ],
    author="Falko Habel",
    author_email="falko.habel@gmx.de",
    description="AIIA deep learning model implementation",
    long_description=open("README.md", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    url="https://gitea.fabelous.app/Maschine-Learning/AIIA.git",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: Free for non-commercial use",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.10",
)
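With this packaging in place, the project can presumably be installed from a checkout in the usual setuptools way, e.g. `pip install -e .` from the repository root.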
@ -0,0 +1,7 @@
from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAExpert, AIIAmoe, AIIA, AIIArecursive
from .model.config import AIIAConfig
from .data.DataLoader import AIIADataLoader
from .pretrain.pretrainer import Pretrainer, ProjectionHead


__version__ = "0.1.0"
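Given these top-level re-exports, a minimal smoke test of the installed package might look like this (a sketch, assuming the package is importable):

```python
import aiia
from aiia import AIIAConfig, AIIABase, Pretrainer

print(aiia.__version__)  # "0.1.0"
```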
@ -0,0 +1,228 @@
import io
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import random
import re
import base64


class FilePathLoader:
    def __init__(self, dataset, file_path_column="file_path", label_column=None):
        self.dataset = dataset
        self.file_path_column = file_path_column
        self.label_column = label_column
        self.successful_count = 0
        self.skipped_count = 0

        if self.file_path_column not in dataset.columns:
            raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")

    def _get_image(self, item):
        # Read the path before the try block so the error message can reference it.
        path = item[self.file_path_column]
        try:
            image = Image.open(path)
            if image.mode == 'RGBA':
                # Flatten the alpha channel onto a black background.
                background = Image.new('RGB', image.size, (0, 0, 0))
                background.paste(image, mask=image.split()[3])
                image = background
            elif image.mode != 'RGB':
                image = image.convert('RGB')
            return image
        except Exception as e:
            print(f"Error loading image from {path}: {e}")
            return None

    def get_item(self, idx):
        item = self.dataset.iloc[idx]
        image = self._get_image(item)
        if image is not None:
            self.successful_count += 1
            if self.label_column is not None:
                label = item.get(self.label_column)
                return (image, label)
            else:
                return (image,)
        else:
            self.skipped_count += 1
            return None

    def print_summary(self):
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")


class JPGImageLoader:
    def __init__(self, dataset, bytes_column="jpg", label_column=None):
        self.dataset = dataset
        self.bytes_column = bytes_column
        self.label_column = label_column
        self.successful_count = 0
        self.skipped_count = 0

        if self.bytes_column not in dataset.columns:
            raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")

    def _get_image(self, item):
        try:
            data = item[self.bytes_column]

            if isinstance(data, str) and data.startswith("b'"):
                # Undo a str() of a bytes literal (e.g. "b'\\xff...'") back into raw bytes.
                cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
                bytes_data = cleaned_data
            elif isinstance(data, str):
                bytes_data = base64.b64decode(data)
            else:
                bytes_data = data

            img_bytes = io.BytesIO(bytes_data)
            image = Image.open(img_bytes)
            if image.mode == 'RGBA':
                background = Image.new('RGB', image.size, (0, 0, 0))
                background.paste(image, mask=image.split()[3])
                image = background
            elif image.mode != 'RGB':
                image = image.convert('RGB')
            return image
        except Exception as e:
            print(f"Error loading image from bytes: {e}")
            return None

    def get_item(self, idx):
        item = self.dataset.iloc[idx]
        image = self._get_image(item)
        if image is not None:
            self.successful_count += 1
            if self.label_column is not None:
                label = item.get(self.label_column)
                return (image, label)
            else:
                return (image,)
        else:
            self.skipped_count += 1
            return None

    def print_summary(self):
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")


class AIIADataLoader:
    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
        self.batch_size = batch_size
        self.val_split = val_split
        self.seed = seed
        self.pretraining = pretraining
        random.seed(seed)

        # Inspect the first value to decide between path-based and bytes-based loading.
        sample_value = dataset[column].iloc[0]
        is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
            isinstance(sample_value, bytes) or
            sample_value.startswith("b'") or
            sample_value.startswith(('b"', 'data:image'))
        )

        if is_bytes_or_bytestring:
            self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
        else:
            sample_paths = dataset[column].dropna().head(1).astype(str)
            filepath_pattern = r'.*(?:/|\\).*\.(jpg|png|gif)$'

            if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
                self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
            else:
                self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)

        self.items = []
        for idx in range(len(dataset)):
            item = self.loader.get_item(idx)
            if item is not None:  # Only add valid items
                if self.pretraining:
                    img = item[0] if isinstance(item, tuple) else item
                    # Each image yields one denoising sample and one rotation sample;
                    # the actual noise/angle is generated in AIIADataset.__getitem__.
                    self.items.append((img, 'denoise', img))
                    self.items.append((img, 'rotate', 0))
                else:
                    self.items.append(item)

        if not self.items:
            raise ValueError("No valid items were loaded from the dataset")

        train_indices, val_indices = self._split_data()

        self.train_dataset = self._create_subset(train_indices)
        self.val_dataset = self._create_subset(val_indices)

        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)

    def _split_data(self):
        if len(self.items) == 0:
            raise ValueError("No items to split")

        num_samples = len(self.items)
        indices = list(range(num_samples))
        random.shuffle(indices)

        split_idx = int((1 - self.val_split) * num_samples)
        train_indices = indices[:split_idx]
        val_indices = indices[split_idx:]

        return train_indices, val_indices

    def _create_subset(self, indices):
        subset_items = [self.items[i] for i in indices]
        return AIIADataset(subset_items, pretraining=self.pretraining)


class AIIADataset(torch.utils.data.Dataset):
    def __init__(self, items, pretraining=False):
        self.items = items
        self.pretraining = pretraining
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]

        if self.pretraining:
            image, task, label = item
            if not isinstance(image, Image.Image):
                raise ValueError(f"Invalid image at index {idx}")

            image = self.transform(image)
            if image.shape != (3, 224, 224):
                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")

            if task == 'denoise':
                noise_std = 0.1
                noisy_img = image + torch.randn_like(image) * noise_std
                target = image.clone()
                return noisy_img, target, task
            elif task == 'rotate':
                angles = [0, 90, 180, 270]
                angle = random.choice(angles)
                rotated_img = transforms.functional.rotate(image, angle)
                target = torch.tensor(angle / 90).long()
                return rotated_img, target, task
            else:
                raise ValueError(f"Invalid task at index {idx}: {task}")
        else:
            if isinstance(item, tuple) and len(item) == 2:
                image, label = item
                if not isinstance(image, Image.Image):
                    raise ValueError(f"Invalid image at index {idx}")
                image = self.transform(image)
                if image.shape != (3, 224, 224):
                    raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
                return image, label
            else:
                if isinstance(item, Image.Image):
                    image = self.transform(item)
                else:
                    image = self.transform(item[0])
                if image.shape != (3, 224, 224):
                    raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
                return image
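For reference, a minimal sketch of how this loader is driven in pretraining mode (the parquet path here is an illustrative assumption, not from the repo; mixed denoise/rotate samples need the `Pretrainer.safe_collate` collate function):

```python
import pandas as pd
from aiia.data import AIIADataLoader
from aiia.pretrain import Pretrainer

# Hypothetical dataset with a "file_path" column pointing at image files.
df = pd.read_parquet("/path/to/images.parquet")

loader = AIIADataLoader(
    df, batch_size=4, val_split=0.2, column="file_path",
    pretraining=True, collate_fn=Pretrainer.safe_collate,
)
for batch in loader.train_loader:
    # batch is {'denoise': (imgs, targets) or None, 'rotate': (imgs, targets) or None}
    break
```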
@ -0,0 +1,3 @@
from .DataLoader import AIIADataLoader

__all__ = ["AIIADataLoader"]
@ -0,0 +1,230 @@
from .config import AIIAConfig
from torch import nn
import torch
import os
import copy


class AIIA(nn.Module):
    def __init__(self, config: AIIAConfig, **kwargs):
        super(AIIA, self).__init__()
        # Create a deep copy of the configuration to avoid sharing
        self.config = copy.deepcopy(config)

        # Update the config with any additional keyword arguments
        for key, value in kwargs.items():
            setattr(self.config, key, value)

    def save(self, path: str):
        if not os.path.exists(path):
            os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), f"{path}/model.pth")
        self.config.save(path)

    @classmethod
    def load(cls, path):
        config = AIIAConfig.load(path)
        model = cls(config)
        model.load_state_dict(torch.load(f"{path}/model.pth"))
        return model

class AIIABaseShared(AIIA):
    def __init__(self, config: AIIAConfig, **kwargs):
        """
        Initialize the AIIABaseShared model.

        Args:
            config (AIIAConfig): Configuration object containing model parameters.
            **kwargs: Additional keyword arguments to override configuration settings.
        """
        super().__init__(config=config, **kwargs)

        # Update configuration with new parameters if provided
        self.config = copy.deepcopy(config)

        for key, value in kwargs.items():
            setattr(self.config, key, value)

        # Initialize the network components
        self._initialize_network()
        self._initialize_activation_and_pooling()

    def _initialize_network(self):
        """Initialize the shared and unique layers of the network."""
        # Create a single shared convolutional layer
        self.shared_layer = nn.Conv2d(
            in_channels=self.config.num_channels,
            out_channels=self.config.hidden_size,
            kernel_size=self.config.kernel_size,
            padding=1  # Using same padding as defined in config
        )

        # Initialize the unique layers with separate weights and biases
        self.unique_layers = nn.ModuleList()
        current_in_channels = self.config.hidden_size

        layer = nn.Conv2d(
            in_channels=current_in_channels,
            out_channels=self.config.hidden_size,
            kernel_size=self.config.kernel_size,
            padding=1  # Using same padding as defined in config
        )

        self.unique_layers.append(layer)

    def _initialize_activation_and_pooling(self):
        """Initialize activation function and pooling layers."""
        # Get activation function from nn module
        self.activation = getattr(nn, self.config.activation_function)()

        # Initialize max pooling layer (PyTorch requires padding to be at most
        # half the kernel size, so a 1x1 kernel needs padding=0)
        self.max_pool = nn.MaxPool2d(
            kernel_size=1,
            stride=1,
            padding=0
        )

    def forward(self, x):
        """Forward pass of the network."""
        # Apply shared layer transformation
        out = self.shared_layer(x)

        # Pass through activation function
        out = self.activation(out)

        # Apply max pooling
        out = self.max_pool(out)

        # Pass through unique layers
        for unique_layer in self.unique_layers:
            out = unique_layer(out)
            out = self.activation(out)
            out = self.max_pool(out)

        return out

class AIIABase(AIIA):
    def __init__(self, config: AIIAConfig, **kwargs):
        super().__init__(config=config, **kwargs)

        # Initialize layers based on configuration
        layers = []
        in_channels = self.config.num_channels

        for _ in range(self.config.num_hidden_layers):
            layers.extend([
                nn.Conv2d(in_channels, self.config.hidden_size,
                          kernel_size=self.config.kernel_size, padding=1),
                getattr(nn, self.config.activation_function)(),
                nn.MaxPool2d(kernel_size=1, stride=1)
            ])
            in_channels = self.config.hidden_size

        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.cnn(x)

class AIIAExpert(AIIA):
    def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)

        # Initialize base CNN with configuration and chosen base class
        if issubclass(base_class, AIIABase):
            self.base_cnn = AIIABase(self.config, **kwargs)
        elif issubclass(base_class, AIIABaseShared):
            self.base_cnn = AIIABaseShared(self.config, **kwargs)
        else:
            raise ValueError("Invalid base class")

    def forward(self, x):
        # Delegate to the wrapped base CNN so experts can be called directly
        return self.base_cnn(x)

class AIIAmoe(AIIA):
    def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)

        # Update config with new parameters if provided
        self.config.num_experts = num_experts

        # Initialize multiple experts using chosen base class
        self.experts = nn.ModuleList([
            AIIAExpert(self.config, base_class=base_class, **kwargs)
            for _ in range(self.config.num_experts)
        ])

        # Create gating network
        self.gate = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.num_experts),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        # (batch, num_experts, hidden_size, H, W)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        # Pool over experts and spatial dims to get a (batch, hidden_size)
        # descriptor for the gate
        gate_input = expert_outputs.mean(dim=(1, 3, 4))
        gate_weights = self.gate(gate_input)  # (batch, num_experts)
        # Broadcast the per-expert weights over channel and spatial dims
        merged_output = torch.sum(
            expert_outputs * gate_weights.view(gate_weights.size(0), -1, 1, 1, 1), dim=1
        )
        return merged_output

class AIIAchunked(AIIA):
    def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)

        # Update config with new parameters if provided
        self.config.patch_size = patch_size

        # Initialize base CNN for processing each patch using the specified base class
        if issubclass(base_class, AIIABase):
            self.base_cnn = AIIABase(self.config, **kwargs)
        elif issubclass(base_class, AIIABaseShared):  # Add support for AIIABaseShared
            self.base_cnn = AIIABaseShared(self.config, **kwargs)
        else:
            raise ValueError("Invalid base class")

    def forward(self, x):
        patch_size = self.config.patch_size
        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, patch_size, patch_size)
        patch_outputs = []

        for p in torch.split(patches, 1, dim=2):
            p = p.squeeze(2)
            po = self.base_cnn(p)
            patch_outputs.append(po)

        combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
        return combined_output

class AIIArecursive(AIIA):
    def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)

        # Pass recursion_depth as a kwarg to the config
        self.config.recursion_depth = recursion_depth

        # Initialize chunked CNN with updated config
        self.chunked_cnn = AIIAchunked(self.config, base_class=base_class, **kwargs)

    def forward(self, x, depth=0):
        if depth == self.config.recursion_depth:
            return self.chunked_cnn(x)
        else:
            patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
            patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
            processed_patches = []

            for p in torch.split(patches, 1, dim=2):
                p = p.squeeze(2)
                pp = self.forward(p, depth + 1)
                processed_patches.append(pp)

            combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
            return combined_output


if __name__ == "__main__":
    config = AIIAConfig()
    model = AIIAmoe(config, num_experts=5)
    model.save("test")
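As a quick sanity check of the base architecture, a dummy forward pass might look like this (a sketch using the default config; the input size matches the 224x224 transform in the data pipeline):

```python
import torch
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig

config = AIIAConfig()                 # defaults: hidden_size=512, num_hidden_layers=12
model = AIIABase(config)

dummy = torch.randn(1, 3, 224, 224)   # one RGB image, as produced by the data pipeline
features = model(dummy)               # convolutions use padding=1, so spatial size is preserved
print(features.shape)                 # torch.Size([1, 512, 224, 224])
```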
@ -0,0 +1,21 @@
from .Model import (
    AIIA,
    AIIABase,
    AIIABaseShared,
    AIIAchunked,
    AIIAExpert,
    AIIAmoe,
    AIIArecursive
)
from .config import AIIAConfig

__all__ = [
    "AIIA",
    "AIIABase",
    "AIIABaseShared",
    "AIIAchunked",
    "AIIAExpert",
    "AIIAmoe",
    "AIIArecursive",
    "AIIAConfig"
]
@ -0,0 +1,53 @@
import torch.nn as nn
import json
import os


class AIIAConfig:
    def __init__(
        self,
        model_name: str = "AIIA",
        kernel_size: int = 3,
        activation_function: str = "GELU",
        hidden_size: int = 512,
        num_hidden_layers: int = 12,
        num_channels: int = 3,
        learning_rate: float = 5e-5,
        **kwargs
    ):
        self.model_name = model_name
        self.kernel_size = kernel_size
        self.activation_function = activation_function
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_channels = num_channels
        self.learning_rate = learning_rate

        # Store additional keyword arguments as attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

    @property
    def activation_function(self):
        return self._activation_function

    @activation_function.setter
    def activation_function(self, value):
        attr = getattr(nn, value, None)
        if attr is None or (not callable(attr) and not isinstance(attr, type(nn.Module))):
            valid_funcs = [func for func in dir(nn) if callable(getattr(nn, func)) or isinstance(getattr(nn, func), type(nn.Module))]
            raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}")
        self._activation_function = value

    def save(self, file_path):
        if not os.path.exists(file_path):
            os.makedirs(file_path, exist_ok=True)
        # Serialize under the public attribute name so load() restores the
        # saved value instead of falling back to the default
        config_dict = vars(self).copy()
        config_dict["activation_function"] = config_dict.pop("_activation_function")
        with open(f"{file_path}/config.json", 'w') as f:
            json.dump(config_dict, f, indent=4)

    @classmethod
    def load(cls, file_path):
        with open(f"{file_path}/config.json", 'r') as f:
            config_dict = json.load(f)
        return cls(**config_dict)
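A minimal round-trip sketch of the config serialization (the directory name is arbitrary):

```python
from aiia.model.config import AIIAConfig

config = AIIAConfig(model_name="demo", custom_flag=True)  # extra kwargs become attributes
config.save("demo_config")                                # writes demo_config/config.json

restored = AIIAConfig.load("demo_config")
assert restored.model_name == "demo"
assert restored.activation_function == "GELU"
```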
@ -0,0 +1,3 @@
from .pretrainer import Pretrainer, ProjectionHead

__all__ = ["Pretrainer", "ProjectionHead"]
@ -0,0 +1,230 @@
import torch
from torch import nn
import csv
import pandas as pd
from tqdm import tqdm
from ..model.Model import AIIA
from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader


class ProjectionHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        else:
            return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task

class Pretrainer:
    def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
        """
        Initialize the pretrainer with a model.

        Args:
            model (AIIA): The model instance to pretrain
            learning_rate (float): Learning rate for optimization
            config (AIIAConfig): Model configuration containing hidden_size
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.to(self.device)
        # Fall back to the model's own config when none is passed explicitly
        if config is None:
            config = model.config
        hidden_size = config.hidden_size
        self.projection_head = ProjectionHead(hidden_size).to(self.device)
        self.optimizer = torch.optim.AdamW(
            list(self.model.parameters()) + list(self.projection_head.parameters()),
            lr=learning_rate
        )
        self.train_losses = []
        self.val_losses = []

    @staticmethod
    def safe_collate(batch):
        """Safely collate batch data handling both denoise and rotate tasks."""
        denoise_batch = []
        rotate_batch = []

        for sample in batch:
            try:
                noisy_img, target, task = sample
                if task == 'denoise':
                    denoise_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
                else:  # rotate task
                    rotate_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
            except Exception as e:
                print(f"Skipping sample due to error: {e}")
                continue

        if not denoise_batch and not rotate_batch:
            return None

        batch_data = {
            'denoise': None,
            'rotate': None
        }

        if denoise_batch:
            images = torch.stack([x['image'] for x in denoise_batch])
            targets = torch.stack([x['target'] for x in denoise_batch])
            batch_data['denoise'] = (images, targets)

        if rotate_batch:
            images = torch.stack([x['image'] for x in rotate_batch])
            targets = torch.stack([x['target'] for x in rotate_batch])
            batch_data['rotate'] = (images, targets)

        return batch_data

    def _process_batch(self, batch_data, criterion_denoise, criterion_rotate, training=True):
        """Process a single batch of data. The `training` flag is informational;
        gradient tracking is controlled by the caller (e.g. torch.no_grad())."""
        batch_loss = 0

        if batch_data['denoise'] is not None:
            noisy_imgs, targets = batch_data['denoise']
            noisy_imgs = noisy_imgs.to(self.device)
            targets = targets.to(self.device)

            features = self.model(noisy_imgs)
            outputs = self.projection_head(features, task='denoise')
            loss = criterion_denoise(outputs, targets)
            batch_loss += loss

        if batch_data['rotate'] is not None:
            imgs, targets = batch_data['rotate']
            imgs = imgs.to(self.device)
            targets = targets.long().to(self.device)

            features = self.model(imgs)
            outputs = self.projection_head(features, task='rotate')
            loss = criterion_rotate(outputs, targets)
            batch_loss += loss

        return batch_loss

    def train(self, dataset_paths, column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
        """
        Train the model using multiple specified datasets.

        Args:
            dataset_paths (list): List of paths to parquet datasets
            column (str): Name of the dataset column containing image data
            num_epochs (int): Number of training epochs
            batch_size (int): Batch size for training
            sample_size (int): Number of samples to use from each dataset
        """
        if not dataset_paths:
            raise ValueError("No dataset paths provided")

        # Read and merge all datasets
        dataframes = []
        for path in dataset_paths:
            try:
                df = pd.read_parquet(path).head(sample_size)
                dataframes.append(df)
            except Exception as e:
                print(f"Error loading dataset {path}: {e}")

        if not dataframes:
            raise ValueError("No valid datasets could be loaded")

        merged_df = pd.concat(dataframes, ignore_index=True)

        # Initialize data loader
        aiia_loader = AIIADataLoader(
            merged_df,
            column=column,
            batch_size=batch_size,
            pretraining=True,
            collate_fn=self.safe_collate
        )

        criterion_denoise = nn.MSELoss()
        criterion_rotate = nn.CrossEntropyLoss()
        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("-" * 20)

            # Training phase
            self.model.train()
            self.projection_head.train()
            total_train_loss = 0.0
            batch_count = 0

            for batch_data in tqdm(aiia_loader.train_loader):
                if batch_data is None:
                    continue

                self.optimizer.zero_grad()
                batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)

                if batch_loss > 0:
                    batch_loss.backward()
                    self.optimizer.step()
                    total_train_loss += batch_loss.item()
                    batch_count += 1

            avg_train_loss = total_train_loss / max(batch_count, 1)
            self.train_losses.append(avg_train_loss)
            print(f"Training Loss: {avg_train_loss:.4f}")

            # Validation phase
            self.model.eval()
            self.projection_head.eval()
            val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.model.save("AIIA-base-512")
                print("Best model saved!")

        self.save_losses('losses.csv')

    def _validate(self, val_loader, criterion_denoise, criterion_rotate):
        """Perform validation and return average validation loss."""
        val_loss = 0.0
        val_batch_count = 0

        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue

                batch_loss = self._process_batch(
                    batch_data, criterion_denoise, criterion_rotate, training=False
                )

                if batch_loss > 0:
                    val_loss += batch_loss.item()
                    val_batch_count += 1

        avg_val_loss = val_loss / max(val_batch_count, 1)
        self.val_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.4f}")
        return avg_val_loss

    def save_losses(self, csv_file):
        """Save training and validation losses to a CSV file."""
        data = list(zip(
            range(1, len(self.train_losses) + 1),
            self.train_losses,
            self.val_losses
        ))

        with open(csv_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
            writer.writerows(data)
        print(f"Loss data has been written to {csv_file}")
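For orientation, a small sketch of the `ProjectionHead` shape contract (hidden_size=256 is an arbitrary illustrative choice):

```python
import torch
from aiia.pretrain import ProjectionHead

head = ProjectionHead(hidden_size=256)
feats = torch.randn(2, 256, 224, 224)        # feature map from the base model

print(head(feats, task='denoise').shape)     # torch.Size([2, 3, 224, 224]) - reconstructed RGB
print(head(feats, task='rotate').shape)      # torch.Size([2, 4]) - logits for 0/90/180/270 degrees
```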