develop #5

MANIFEST.in
include LICENSE
include README.md
include requirements.txt
recursive-include src/aiia *

README.md
# AIIA

## Example Usage

```python
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer

# Create your model
config = AIIAConfig(model_name="AIIA-Base-512x20k")
model = AIIABase(config)

# Initialize pretrainer with the model and its config
pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)

# List of dataset paths
dataset_paths = [
    "/path/to/dataset1.parquet",
    "/path/to/dataset2.parquet"
]

# Start training with multiple datasets
pretrainer.train(
    dataset_paths=dataset_paths,
    num_epochs=10,
    batch_size=2,
    sample_size=10000
)
```
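
After pretraining, the model can be saved and reloaded with the `save`/`load` helpers on the `AIIA` base class. A minimal sketch (the directory name is illustrative):

```python
from aiia.model import AIIABase

# Persist weights (model.pth) and config (config.json) to a directory
model.save("AIIA-Base-512x20k")

# Restore the model from that directory later
restored = AIIABase.load("AIIA-Base-512x20k")
```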

Example pretraining script (small model; filename not shown in the diff):
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer

data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"

# Create your model
config = AIIAConfig(model_name="AIIA-Base-512x10k-small", num_hidden_layers=6, hidden_size=256)
model = AIIABase(config)

# Initialize pretrainer with the model
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)

# List of dataset paths
dataset_paths = [
    data_path1,
    data_path2
]

# Start training with multiple datasets
pretrainer.train(
    dataset_paths=dataset_paths,
    num_epochs=10,
    batch_size=2,
    sample_size=10000
)

pyproject.toml
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[tool.black]
line-length = 88
target-version = ['py37']
include = '\.pyi?$'

requirements.txt
torch>=1.8.0
torchvision
numpy
pandas
tqdm
pytest
pillow

Example pretraining script (filename not shown in the diff):
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer

data_path1 = "/root/training_data/vision-dataset/images_pretrain.parquet"
data_path2 = "/root/training_data/vision-dataset/vector_img_pretrain.parquet"

# Create your model
config = AIIAConfig(model_name="AIIA-Base-512x20k")
model = AIIABase(config)

# Initialize pretrainer with the model
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)

# List of dataset paths
dataset_paths = [
    data_path1,
    data_path2
]

# Start training with multiple datasets
pretrainer.train(
    dataset_paths=dataset_paths,
    num_epochs=10,
    batch_size=2,
    sample_size=10000
)

setup.cfg
[metadata]
name = aiia
version = 0.1.0
author = Falko Habel
author_email = falko.habel@gmx.de
description = AIIA deep learning model implementation
long_description = file: README.md
long_description_content_type = text/markdown
url = https://gitea.fabelous.app/Maschine-Learning/AIIA.git
classifiers =
    Programming Language :: Python :: 3
    License :: Free for non-commercial use
    Operating System :: OS Independent

[options]
package_dir =
    = src
packages = find:
python_requires = >=3.10
install_requires =
    torch>=1.8.0
    torchvision
    numpy>=1.19.0
    pandas
    pillow
    tqdm>=4.62.0

[options.packages.find]
where = src

setup.py
from setuptools import setup, find_packages

setup(
    name="aiia",
    version="0.1.0",
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    install_requires=[
        "torch>=1.8.0",
        "torchvision",
        "numpy>=1.19.0",
        "pandas",
        "pillow",
        "tqdm>=4.62.0",
    ],
    author="Falko Habel",
    author_email="falko.habel@gmx.de",
    description="AIIA deep learning model implementation",
    long_description=open("README.md", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    url="https://gitea.fabelous.app/Maschine-Learning/AIIA.git",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: Free for non-commercial use",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.10",
)

src/aiia/__init__.py
from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAExpert, AIIAmoe, AIIA, AIIArecursive
from .model.config import AIIAConfig
from .data.DataLoader import AIIADataLoader
from .pretrain.pretrainer import Pretrainer, ProjectionHead

__version__ = "0.1.0"

src/aiia/data/DataLoader.py
import io
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import random
import re
import base64


class FilePathLoader:
    def __init__(self, dataset, file_path_column="file_path", label_column=None):
        self.dataset = dataset
        self.file_path_column = file_path_column
        self.label_column = label_column
        self.successful_count = 0
        self.skipped_count = 0

        if self.file_path_column not in dataset.columns:
            raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")

    def _get_image(self, item):
        path = item[self.file_path_column]
        try:
            image = Image.open(path)
            if image.mode == 'RGBA':
                # Flatten the alpha channel onto a black background
                background = Image.new('RGB', image.size, (0, 0, 0))
                background.paste(image, mask=image.split()[3])
                image = background
            elif image.mode != 'RGB':
                image = image.convert('RGB')
            return image
        except Exception as e:
            print(f"Error loading image from {path}: {e}")
            return None

    def get_item(self, idx):
        item = self.dataset.iloc[idx]
        image = self._get_image(item)
        if image is not None:
            self.successful_count += 1
            if self.label_column is not None:
                label = item.get(self.label_column)
                return (image, label)
            else:
                return (image,)
        else:
            self.skipped_count += 1
            return None

    def print_summary(self):
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")


class JPGImageLoader:
    def __init__(self, dataset, bytes_column="jpg", label_column=None):
        self.dataset = dataset
        self.bytes_column = bytes_column
        self.label_column = label_column
        self.successful_count = 0
        self.skipped_count = 0

        if self.bytes_column not in dataset.columns:
            raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")

    def _get_image(self, item):
        try:
            data = item[self.bytes_column]

            if isinstance(data, str) and data.startswith("b'"):
                # Undo a str() of raw bytes (e.g. "b'\xff...'") back into bytes
                bytes_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
            elif isinstance(data, str):
                bytes_data = base64.b64decode(data)
            else:
                bytes_data = data

            img_bytes = io.BytesIO(bytes_data)
            image = Image.open(img_bytes)
            if image.mode == 'RGBA':
                background = Image.new('RGB', image.size, (0, 0, 0))
                background.paste(image, mask=image.split()[3])
                image = background
            elif image.mode != 'RGB':
                image = image.convert('RGB')
            return image
        except Exception as e:
            print(f"Error loading image from bytes: {e}")
            return None

    def get_item(self, idx):
        item = self.dataset.iloc[idx]
        image = self._get_image(item)
        if image is not None:
            self.successful_count += 1
            if self.label_column is not None:
                label = item.get(self.label_column)
                return (image, label)
            else:
                return (image,)
        else:
            self.skipped_count += 1
            return None

    def print_summary(self):
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")


class AIIADataLoader:
    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
        self.batch_size = batch_size
        self.val_split = val_split
        self.seed = seed
        self.pretraining = pretraining
        random.seed(seed)

        # Decide between raw-bytes data and file-path data from the first value
        sample_value = dataset[column].iloc[0]
        is_bytes_or_bytestring = isinstance(sample_value, bytes) or (
            isinstance(sample_value, str)
            and sample_value.startswith(("b'", 'b"', 'data:image'))
        )

        if is_bytes_or_bytestring:
            self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
        else:
            sample_paths = dataset[column].dropna().head(1).astype(str)
            filepath_pattern = r'.*(?:/|\\).*\.(jpg|png|gif)$'

            if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
                self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
            else:
                self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)

        self.items = []
        for idx in range(len(dataset)):
            item = self.loader.get_item(idx)
            if item is not None:  # Only add valid items
                if self.pretraining:
                    img = item[0] if isinstance(item, tuple) else item
                    self.items.append((img, 'denoise', img))
                    self.items.append((img, 'rotate', 0))
                else:
                    self.items.append(item)

        if not self.items:
            raise ValueError("No valid items were loaded from the dataset")

        train_indices, val_indices = self._split_data()

        self.train_dataset = self._create_subset(train_indices)
        self.val_dataset = self._create_subset(val_indices)

        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)

    def _split_data(self):
        if len(self.items) == 0:
            raise ValueError("No items to split")

        num_samples = len(self.items)
        indices = list(range(num_samples))
        random.shuffle(indices)

        split_idx = int((1 - self.val_split) * num_samples)
        train_indices = indices[:split_idx]
        val_indices = indices[split_idx:]

        return train_indices, val_indices

    def _create_subset(self, indices):
        subset_items = [self.items[i] for i in indices]
        return AIIADataset(subset_items, pretraining=self.pretraining)


class AIIADataset(torch.utils.data.Dataset):
    def __init__(self, items, pretraining=False):
        self.items = items
        self.pretraining = pretraining
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]

        if self.pretraining:
            image, task, label = item
            if not isinstance(image, Image.Image):
                raise ValueError(f"Invalid image at index {idx}")

            image = self.transform(image)
            if image.shape != (3, 224, 224):
                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")

            if task == 'denoise':
                noise_std = 0.1
                noisy_img = image + torch.randn_like(image) * noise_std
                target = image.clone()
                return noisy_img, target, task
            elif task == 'rotate':
                angles = [0, 90, 180, 270]
                angle = random.choice(angles)
                rotated_img = transforms.functional.rotate(image, angle)
                target = torch.tensor(angle / 90).long()
                return rotated_img, target, task
            else:
                raise ValueError(f"Invalid task at index {idx}: {task}")
        else:
            if isinstance(item, tuple) and len(item) == 2:
                image, label = item
                if not isinstance(image, Image.Image):
                    raise ValueError(f"Invalid image at index {idx}")
                image = self.transform(image)
                if image.shape != (3, 224, 224):
                    raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
                return image, label
            else:
                if isinstance(item, Image.Image):
                    image = self.transform(item)
                else:
                    image = self.transform(item[0])
                if image.shape != (3, 224, 224):
                    raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
                return image
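
For reference, a minimal sketch of driving this loader directly from a parquet file; the path and column name are placeholders, and the pretrainer's `safe_collate` is reused so mixed denoise/rotate samples batch cleanly:

```python
import pandas as pd
from aiia.data.DataLoader import AIIADataLoader
from aiia.pretrain import Pretrainer

# Build train/val loaders over a parquet dataset for the pretraining tasks
df = pd.read_parquet("/path/to/images.parquet")
loader = AIIADataLoader(
    df,
    column="image_bytes",
    batch_size=4,
    pretraining=True,
    collate_fn=Pretrainer.safe_collate,
)

batch = next(iter(loader.train_loader))
# batch is a dict: batch['denoise'] -> (noisy_imgs, clean_targets) or None,
# batch['rotate'] -> (rotated_imgs, angle_class_targets) or None
```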

src/aiia/data/__init__.py
from .DataLoader import AIIADataLoader

__all__ = ["AIIADataLoader"]

src/aiia/model/Model.py
from .config import AIIAConfig
from torch import nn
import torch
import os
import copy


class AIIA(nn.Module):
    def __init__(self, config: AIIAConfig, **kwargs):
        super(AIIA, self).__init__()
        # Create a deep copy of the configuration to avoid sharing
        self.config = copy.deepcopy(config)

        # Update the config with any additional keyword arguments
        for key, value in kwargs.items():
            setattr(self.config, key, value)

    def save(self, path: str):
        if not os.path.exists(path):
            os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), f"{path}/model.pth")
        self.config.save(path)

    @classmethod
    def load(cls, path):
        config = AIIAConfig.load(path)
        model = cls(config)
        model.load_state_dict(torch.load(f"{path}/model.pth"))
        return model


class AIIABaseShared(AIIA):
    def __init__(self, config: AIIAConfig, **kwargs):
        """
        Initialize the AIIABaseShared model.

        Args:
            config (AIIAConfig): Configuration object containing model parameters.
            **kwargs: Additional keyword arguments to override configuration settings.
        """
        # The AIIA base class already deep-copies the config and applies kwargs
        super().__init__(config=config, **kwargs)

        # Initialize the network components
        self._initialize_network()
        self._initialize_activation_and_pooling()

    def _initialize_network(self):
        """Initialize the shared and unique layers of the network."""
        # Create a single shared convolutional layer
        self.shared_layer = nn.Conv2d(
            in_channels=self.config.num_channels,
            out_channels=self.config.hidden_size,
            kernel_size=self.config.kernel_size,
            padding=1  # Using same padding as defined in config
        )

        # Initialize the unique layers with separate weights and biases
        self.unique_layers = nn.ModuleList()
        current_in_channels = self.config.hidden_size

        layer = nn.Conv2d(
            in_channels=current_in_channels,
            out_channels=self.config.hidden_size,
            kernel_size=self.config.kernel_size,
            padding=1  # Using same padding as defined in config
        )

        self.unique_layers.append(layer)

    def _initialize_activation_and_pooling(self):
        """Initialize activation function and pooling layers."""
        # Get activation function from nn module
        self.activation = getattr(nn, self.config.activation_function)()

        # Initialize max pooling layer; a 1x1/stride-1 pool keeps the spatial
        # size, and padding must be 0 (PyTorch requires padding <= kernel_size / 2)
        self.max_pool = nn.MaxPool2d(
            kernel_size=1,
            stride=1,
            padding=0
        )

    def forward(self, x):
        """Forward pass of the network."""
        # Apply shared layer transformation
        out = self.shared_layer(x)

        # Pass through activation function
        out = self.activation(out)

        # Apply max pooling
        out = self.max_pool(out)

        # Pass through unique layers
        for unique_layer in self.unique_layers:
            out = unique_layer(out)
            out = self.activation(out)
            out = self.max_pool(out)

        return out


class AIIABase(AIIA):
    def __init__(self, config: AIIAConfig, **kwargs):
        super().__init__(config=config, **kwargs)

        # Initialize layers based on configuration
        layers = []
        in_channels = self.config.num_channels

        for _ in range(self.config.num_hidden_layers):
            layers.extend([
                nn.Conv2d(in_channels, self.config.hidden_size,
                          kernel_size=self.config.kernel_size, padding=1),
                getattr(nn, self.config.activation_function)(),
                nn.MaxPool2d(kernel_size=1, stride=1)
            ])
            in_channels = self.config.hidden_size

        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.cnn(x)


class AIIAExpert(AIIA):
    def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)

        # Initialize base CNN with configuration and chosen base class
        if issubclass(base_class, AIIABase):
            self.base_cnn = AIIABase(self.config, **kwargs)
        elif issubclass(base_class, AIIABaseShared):
            self.base_cnn = AIIABaseShared(self.config, **kwargs)
        else:
            raise ValueError("Invalid base class")

    def forward(self, x):
        # Delegate to the wrapped base CNN so the expert is callable
        return self.base_cnn(x)


class AIIAmoe(AIIA):
    def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)

        # Update config with new parameters if provided
        self.config.num_experts = num_experts

        # Initialize multiple experts using chosen base class
        self.experts = nn.ModuleList([
            AIIAExpert(self.config, base_class=base_class, **kwargs)
            for _ in range(self.config.num_experts)
        ])

        # Create gating network
        self.gate = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.num_experts),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        # expert_outputs: (batch, num_experts, hidden_size, H, W)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        # Pool over experts and spatial dims to get a (batch, hidden_size) gate input
        gate_weights = self.gate(expert_outputs.mean(dim=(1, 3, 4)))
        # Weight each expert's output and sum over the expert dimension
        merged_output = torch.sum(
            expert_outputs * gate_weights.unsqueeze(2).unsqueeze(3).unsqueeze(4), dim=1
        )
        return merged_output


class AIIAchunked(AIIA):
    def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)

        # Update config with new parameters if provided
        self.config.patch_size = patch_size

        # Initialize base CNN for processing each patch using the specified base class
        if issubclass(base_class, AIIABase):
            self.base_cnn = AIIABase(self.config, **kwargs)
        elif issubclass(base_class, AIIABaseShared):  # Add support for AIIABaseShared
            self.base_cnn = AIIABaseShared(self.config, **kwargs)
        else:
            raise ValueError("Invalid base class")

    def forward(self, x):
        patch_size = self.config.patch_size
        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, patch_size, patch_size)
        patch_outputs = []

        for p in torch.split(patches, 1, dim=2):
            p = p.squeeze(2)
            po = self.base_cnn(p)
            patch_outputs.append(po)

        combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
        return combined_output


class AIIArecursive(AIIA):
    def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)

        # Store recursion_depth on the config
        self.config.recursion_depth = recursion_depth

        # Initialize chunked CNN with updated config; base_class must be passed
        # by keyword so it is not mistaken for patch_size
        self.chunked_cnn = AIIAchunked(self.config, base_class=base_class, **kwargs)

    def forward(self, x, depth=0):
        if depth == self.config.recursion_depth:
            return self.chunked_cnn(x)
        else:
            patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
            patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
            processed_patches = []

            for p in torch.split(patches, 1, dim=2):
                p = p.squeeze(2)
                pp = self.forward(p, depth + 1)
                processed_patches.append(pp)

            combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
            return combined_output


if __name__ == "__main__":
    # Quick smoke test; guarded so importing this module has no side effects
    config = AIIAConfig()
    model = AIIAmoe(config, num_experts=5)
    model.save("test")
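
As a sanity check on the mixture-of-experts wiring above, a minimal sketch of a forward pass with a dummy batch (the input size is illustrative; the default config expects 3-channel input):

```python
import torch
from aiia.model import AIIAmoe
from aiia.model.config import AIIAConfig

config = AIIAConfig()
model = AIIAmoe(config, num_experts=3)

# Dummy batch: 2 RGB images of 32x32 pixels
x = torch.randn(2, 3, 32, 32)
out = model(x)
print(out.shape)  # (2, hidden_size, 32, 32); the experts preserve spatial size
```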

src/aiia/model/__init__.py
from .Model import (
    AIIA,
    AIIABase,
    AIIABaseShared,
    AIIAchunked,
    AIIAExpert,
    AIIAmoe,
    AIIArecursive
)
from .config import AIIAConfig

__all__ = [
    "AIIA",
    "AIIABase",
    "AIIABaseShared",
    "AIIAchunked",
    "AIIAExpert",
    "AIIAmoe",
    "AIIArecursive",
    "AIIAConfig"
]

src/aiia/model/config.py
import torch.nn as nn
import json
import os


class AIIAConfig:
    def __init__(
        self,
        model_name: str = "AIIA",
        kernel_size: int = 3,
        activation_function: str = "GELU",
        hidden_size: int = 512,
        num_hidden_layers: int = 12,
        num_channels: int = 3,
        learning_rate: float = 5e-5,
        **kwargs
    ):
        self.model_name = model_name
        self.kernel_size = kernel_size
        self.activation_function = activation_function
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_channels = num_channels
        self.learning_rate = learning_rate

        # Store additional keyword arguments as attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

    @property
    def activation_function(self):
        return self._activation_function

    @activation_function.setter
    def activation_function(self, value):
        attr = getattr(nn, value, None)
        if attr is None or (not callable(attr) and not isinstance(attr, type(nn.Module))):
            valid_funcs = [func for func in dir(nn) if callable(getattr(nn, func)) or isinstance(getattr(nn, func), type(nn.Module))]
            raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}")
        self._activation_function = value

    def save(self, file_path):
        if not os.path.exists(file_path):
            os.makedirs(file_path, exist_ok=True)
        # Serialize private attributes (e.g. _activation_function) under their
        # public names so the JSON round-trips cleanly through __init__
        config_dict = {k.lstrip('_'): v for k, v in vars(self).items()}
        with open(f"{file_path}/config.json", 'w') as f:
            json.dump(config_dict, f, indent=4)

    @classmethod
    def load(cls, file_path):
        with open(f"{file_path}/config.json", 'r') as f:
            config_dict = json.load(f)
        return cls(**config_dict)
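
A minimal save/load round-trip sketch for the config (the directory name is illustrative):

```python
from aiia.model.config import AIIAConfig

config = AIIAConfig(model_name="demo", hidden_size=256)
config.save("demo_config")  # writes demo_config/config.json

restored = AIIAConfig.load("demo_config")
assert restored.hidden_size == 256
assert restored.activation_function == "GELU"
```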

src/aiia/pretrain/__init__.py
from .pretrainer import Pretrainer, ProjectionHead

__all__ = ["Pretrainer", "ProjectionHead"]

src/aiia/pretrain/pretrainer.py
import torch
from torch import nn
import csv
import pandas as pd
from tqdm import tqdm
from ..model.Model import AIIA
from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader


class ProjectionHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        else:
            return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task


class Pretrainer:
    def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig = None):
        """
        Initialize the pretrainer with a model.

        Args:
            model (AIIA): The model instance to pretrain
            learning_rate (float): Learning rate for optimization
            config (AIIAConfig): Model configuration containing hidden_size;
                defaults to the model's own config if omitted
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.to(self.device)
        if config is None:
            config = model.config
        hidden_size = config.hidden_size
        self.projection_head = ProjectionHead(hidden_size).to(self.device)
        self.optimizer = torch.optim.AdamW(
            list(self.model.parameters()) + list(self.projection_head.parameters()),
            lr=learning_rate
        )
        self.train_losses = []
        self.val_losses = []

    @staticmethod
    def safe_collate(batch):
        """Safely collate batch data, handling both denoise and rotate tasks."""
        denoise_batch = []
        rotate_batch = []

        for sample in batch:
            try:
                noisy_img, target, task = sample
                if task == 'denoise':
                    denoise_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
                else:  # rotate task
                    rotate_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
            except Exception as e:
                print(f"Skipping sample due to error: {e}")
                continue

        if not denoise_batch and not rotate_batch:
            return None

        batch_data = {
            'denoise': None,
            'rotate': None
        }

        if denoise_batch:
            images = torch.stack([x['image'] for x in denoise_batch])
            targets = torch.stack([x['target'] for x in denoise_batch])
            batch_data['denoise'] = (images, targets)

        if rotate_batch:
            images = torch.stack([x['image'] for x in rotate_batch])
            targets = torch.stack([x['target'] for x in rotate_batch])
            batch_data['rotate'] = (images, targets)

        return batch_data

    def _process_batch(self, batch_data, criterion_denoise, criterion_rotate):
        """Process a single batch of data and return the combined loss."""
        batch_loss = 0

        if batch_data['denoise'] is not None:
            noisy_imgs, targets = batch_data['denoise']
            noisy_imgs = noisy_imgs.to(self.device)
            targets = targets.to(self.device)

            features = self.model(noisy_imgs)
            outputs = self.projection_head(features, task='denoise')
            loss = criterion_denoise(outputs, targets)
            batch_loss += loss

        if batch_data['rotate'] is not None:
            imgs, targets = batch_data['rotate']
            imgs = imgs.to(self.device)
            targets = targets.long().to(self.device)

            features = self.model(imgs)
            outputs = self.projection_head(features, task='rotate')
            loss = criterion_rotate(outputs, targets)
            batch_loss += loss

        return batch_loss

    def train(self, dataset_paths, column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
        """
        Train the model using multiple specified datasets.

        Args:
            dataset_paths (list): List of paths to parquet datasets
            column (str): Name of the dataset column holding the images
            num_epochs (int): Number of training epochs
            batch_size (int): Batch size for training
            sample_size (int): Number of samples to use from each dataset
        """
        if not dataset_paths:
            raise ValueError("No dataset paths provided")

        # Read and merge all datasets
        dataframes = []
        for path in dataset_paths:
            try:
                df = pd.read_parquet(path).head(sample_size)
                dataframes.append(df)
            except Exception as e:
                print(f"Error loading dataset {path}: {e}")

        if not dataframes:
            raise ValueError("No valid datasets could be loaded")

        merged_df = pd.concat(dataframes, ignore_index=True)

        # Initialize data loader
        aiia_loader = AIIADataLoader(
            merged_df,
            column=column,
            batch_size=batch_size,
            pretraining=True,
            collate_fn=self.safe_collate
        )

        criterion_denoise = nn.MSELoss()
        criterion_rotate = nn.CrossEntropyLoss()
        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("-" * 20)

            # Training phase
            self.model.train()
            self.projection_head.train()
            total_train_loss = 0.0
            batch_count = 0

            for batch_data in tqdm(aiia_loader.train_loader):
                if batch_data is None:
                    continue

                self.optimizer.zero_grad()
                batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)

                if batch_loss > 0:
                    batch_loss.backward()
                    self.optimizer.step()
                    total_train_loss += batch_loss.item()
                    batch_count += 1

            avg_train_loss = total_train_loss / max(batch_count, 1)
            self.train_losses.append(avg_train_loss)
            print(f"Training Loss: {avg_train_loss:.4f}")

            # Validation phase
            self.model.eval()
            self.projection_head.eval()
            val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.model.save("AIIA-base-512")
                print("Best model saved!")

        self.save_losses('losses.csv')

    def _validate(self, val_loader, criterion_denoise, criterion_rotate):
        """Perform validation and return average validation loss."""
        val_loss = 0.0
        val_batch_count = 0

        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue

                batch_loss = self._process_batch(
                    batch_data, criterion_denoise, criterion_rotate
                )

                if batch_loss > 0:
                    val_loss += batch_loss.item()
                    val_batch_count += 1

        avg_val_loss = val_loss / max(val_batch_count, 1)
        self.val_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.4f}")
        return avg_val_loss

    def save_losses(self, csv_file):
        """Save training and validation losses to a CSV file."""
        data = list(zip(
            range(1, len(self.train_losses) + 1),
            self.train_losses,
            self.val_losses
        ))

        with open(csv_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
            writer.writerows(data)
        print(f"Loss data has been written to {csv_file}")
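
For clarity on the two pretraining heads, a small sketch of `ProjectionHead` in isolation (shapes assume the default `hidden_size` of 512 and an illustrative 32x32 feature map):

```python
import torch
from aiia.pretrain import ProjectionHead

head = ProjectionHead(hidden_size=512)
features = torch.randn(2, 512, 32, 32)  # output of the backbone CNN

denoised = head(features, task='denoise')
print(denoised.shape)  # torch.Size([2, 3, 32, 32]); reconstructed RGB image

logits = head(features, task='rotate')
print(logits.shape)    # torch.Size([2, 4]); one logit per rotation class
```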