Merge pull request 'first_pip' (#4) from first_pip into develop

Reviewed-on: Fabel/AIIA#4
Falko Victor Habel 2025-01-29 21:33:31 +00:00
commit 0d5e5688eb
15 changed files with 418 additions and 232 deletions

MANIFEST.in (new file, +4)

@@ -0,0 +1,4 @@
include LICENSE
include README.md
include requirements.txt
recursive-include src/aiia *

README.md (modified, +28)

@@ -1,2 +1,30 @@
# AIIA
## Example Usage:
```Python
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer

# Create your model
config = AIIAConfig(model_name="AIIA-Base-512x20k")
model = AIIABase(config)

# Initialize pretrainer with the model (config is required: Pretrainer reads
# config.hidden_size to size its projection head)
pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)

# List of dataset paths
dataset_paths = [
    "/path/to/dataset1.parquet",
    "/path/to/dataset2.parquet"
]

# Start training with multiple datasets
pretrainer.train(
    dataset_paths=dataset_paths,
    num_epochs=10,
    batch_size=2,
    sample_size=10000
)
```

example.py (new file, +27)

@@ -0,0 +1,27 @@
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer

data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"

# Create your model
config = AIIAConfig(model_name="AIIA-Base-512x10k-small", num_hidden_layers=6, hidden_size=256)
model = AIIABase(config)

# Initialize pretrainer with the model
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)

# List of dataset paths
dataset_paths = [
    data_path1,
    data_path2
]

# Start training with multiple datasets
pretrainer.train(
    dataset_paths=dataset_paths,
    num_epochs=10,
    batch_size=2,
    sample_size=10000
)

pyproject.toml (new file, +8)

@@ -0,0 +1,8 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[tool.black]
line-length = 88
target-version = ['py310']
include = '\.pyi?$'
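Since the build backend is plain setuptools, a quick local sanity check is possible after an editable install (`pip install -e .` from the repo root). A hypothetical smoke test; the expected values come from the files in this commit:

```Python
# Hypothetical smoke test after `pip install -e .`:
import aiia

print(aiia.__version__)               # "0.1.0", set in src/aiia/__init__.py
from aiia.pretrain import Pretrainer  # resolves through the src/ layout
```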

requirements.txt (new file, +5)

@@ -0,0 +1,5 @@
torch>=1.8.0
numpy
tqdm
pytest
pillow

run.py (new file, +27)

@@ -0,0 +1,27 @@
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer

data_path1 = "/root/training_data/vision-dataset/images_pretrain.parquet"
data_path2 = "/root/training_data/vision-dataset/vector_img_pretrain.parquet"

# Create your model
config = AIIAConfig(model_name="AIIA-Base-512x20k")
model = AIIABase(config)

# Initialize pretrainer with the model
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)

# List of dataset paths
dataset_paths = [
    data_path1,
    data_path2
]

# Start training with multiple datasets
pretrainer.train(
    dataset_paths=dataset_paths,
    num_epochs=10,
    batch_size=2,
    sample_size=10000
)
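Aside from the dataset paths, run.py differs from example.py only in the model configuration: example.py builds the smaller `AIIA-Base-512x10k-small` (6 hidden layers, hidden size 256), while run.py uses the default-sized `AIIA-Base-512x20k`.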

setup.cfg (new file, +26)

@@ -0,0 +1,26 @@
[metadata]
name = aiia
version = 0.1.0
author = Falko Habel
author_email = falko.habel@gmx.de
description = AIIA deep learning model implementation
long_description = file: README.md
long_description_content_type = text/markdown
url = https://gitea.fabelous.app/Maschine-Learning/AIIA.git
classifiers =
    Programming Language :: Python :: 3
    License :: Free for non-commercial use
    Operating System :: OS Independent
[options]
package_dir =
    = src
packages = find:
python_requires = >=3.10
install_requires =
    torch>=1.8.0
    numpy>=1.19.0
    tqdm>=4.62.0

[options.packages.find]
where = src
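Note that setup.cfg and setup.py declare overlapping metadata (name, version, dependencies, classifiers); keeping a single canonical source, typically setup.cfg or pyproject.toml, avoids the two files drifting apart.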

setup.py (new file, +25)

@@ -0,0 +1,25 @@
from setuptools import setup, find_packages

setup(
    name="aiia",
    version="0.1.0",
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    install_requires=[
        "torch>=1.8.0",
        "numpy>=1.19.0",
        "tqdm>=4.62.0",
    ],
    author="Falko Habel",
    author_email="falko.habel@gmx.de",
    description="AIIA deep learning model implementation",
    long_description=open("README.md", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    url="https://gitea.fabelous.app/Maschine-Learning/AIIA.git",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: Free for non-commercial use",  # CC BY-NC 4.0 is not an OSI-approved license
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.10",
)

src/aiia/__init__.py (modified)

@@ -1,3 +1,7 @@
-from .model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIArecursive, AIIABaseShared
-from .data import AIIADataLoader
-from .model.config import AIIAConfig
+from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAExpert, AIIAmoe, AIIA, AIIArecursive
+from .model.config import AIIAConfig
+from .data.DataLoader import DataLoader
+from .pretrain.pretrainer import Pretrainer, ProjectionHead
+
+__version__ = "0.1.0"

src/aiia/data/__init__.py (modified)

@@ -1 +1,3 @@
 from .DataLoader import AIIADataLoader
+
+__all__ = ["AIIADataLoader"]

src/aiia/model/Model.py (modified)

@@ -223,4 +223,8 @@ class AIIArecursive(AIIA):
             processed_patches.append(pp)

         combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
         return combined_output
+
+config = AIIAConfig()
+model = AIIAmoe(config, num_experts=5)
+model.save("test")
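The three added module-level lines look like leftover smoke-test scaffolding: importing `aiia.model.Model` will now build an `AIIAmoe` model and write a "test" checkpoint as a side effect.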

src/aiia/model/__init__.py (modified)

@@ -1,2 +1,21 @@
+from .Model import (
+    AIIA,
+    AIIABase,
+    AIIABaseShared,
+    AIIAchunked,
+    AIIAExpert,
+    AIIAmoe,
+    AIIArecursive
+)
 from .config import AIIAConfig
-from .Model import AIIA, AIIABase, AIIAchunked, AIIAExpert, AIIAmoe, AIIArecursive, AIIABaseShared
+
+__all__ = [
+    "AIIA",
+    "AIIABase",
+    "AIIABaseShared",
+    "AIIAchunked",
+    "AIIAExpert",
+    "AIIAmoe",
+    "AIIArecursive",
+    "AIIAConfig"
+]
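With the expanded exports, the mixture-of-experts variant can be imported directly from the package. A minimal sketch (assumes the package is installed, and reuses the constructor call from Model.py above; the save name is an arbitrary example):

```Python
from aiia.model import AIIAmoe, AIIAConfig

config = AIIAConfig()
model = AIIAmoe(config, num_experts=5)  # same call as the scaffolding in Model.py
model.save("moe-test")                  # "moe-test" is a hypothetical name
```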

src/aiia/pretrain/__init__.py (new file, +3)

@@ -0,0 +1,3 @@
from .pretrainer import Pretrainer, ProjectionHead

__all__ = ["Pretrainer", "ProjectionHead"]

src/aiia/pretrain/pretrainer.py (new file, +230)

@@ -0,0 +1,230 @@
import torch
from torch import nn
import csv
import pandas as pd
from tqdm import tqdm
from ..model.Model import AIIA
from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader

class ProjectionHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)  # 4 classes: 0, 90, 180, 270 degrees

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        else:
            return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for the rotation task

class Pretrainer:
    def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig = None):
        """
        Initialize the pretrainer with a model.

        Args:
            model (AIIA): The model instance to pretrain
            learning_rate (float): Learning rate for optimization
            config (AIIAConfig): Model configuration providing hidden_size (required)
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.to(self.device)
        hidden_size = config.hidden_size
        self.projection_head = ProjectionHead(hidden_size).to(self.device)
        self.optimizer = torch.optim.AdamW(
            list(self.model.parameters()) + list(self.projection_head.parameters()),
            lr=learning_rate
        )
        self.train_losses = []
        self.val_losses = []

    @staticmethod
    def safe_collate(batch):
        """Safely collate batch data, handling both denoise and rotate tasks."""
        denoise_batch = []
        rotate_batch = []

        for sample in batch:
            try:
                noisy_img, target, task = sample
                if task == 'denoise':
                    denoise_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
                else:  # rotate task
                    rotate_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
            except Exception as e:
                print(f"Skipping sample due to error: {e}")
                continue

        if not denoise_batch and not rotate_batch:
            return None

        batch_data = {
            'denoise': None,
            'rotate': None
        }

        if denoise_batch:
            images = torch.stack([x['image'] for x in denoise_batch])
            targets = torch.stack([x['target'] for x in denoise_batch])
            batch_data['denoise'] = (images, targets)

        if rotate_batch:
            images = torch.stack([x['image'] for x in rotate_batch])
            targets = torch.stack([x['target'] for x in rotate_batch])
            batch_data['rotate'] = (images, targets)

        return batch_data

    def _process_batch(self, batch_data, criterion_denoise, criterion_rotate, training=True):
        """Process a single batch of data. (The training flag is currently unused;
        gradient handling is controlled by the caller.)"""
        batch_loss = 0

        if batch_data['denoise'] is not None:
            noisy_imgs, targets = batch_data['denoise']
            noisy_imgs = noisy_imgs.to(self.device)
            targets = targets.to(self.device)

            features = self.model(noisy_imgs)
            outputs = self.projection_head(features, task='denoise')
            loss = criterion_denoise(outputs, targets)
            batch_loss += loss

        if batch_data['rotate'] is not None:
            imgs, targets = batch_data['rotate']
            imgs = imgs.to(self.device)
            targets = targets.long().to(self.device)

            features = self.model(imgs)
            outputs = self.projection_head(features, task='rotate')
            loss = criterion_rotate(outputs, targets)
            batch_loss += loss

        return batch_loss

    def train(self, dataset_paths, column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
        """
        Train the model using multiple specified datasets.

        Args:
            dataset_paths (list): List of paths to parquet datasets
            column (str): Name of the dataframe column holding the image bytes
            num_epochs (int): Number of training epochs
            batch_size (int): Batch size for training
            sample_size (int): Number of samples to use from each dataset
        """
        if not dataset_paths:
            raise ValueError("No dataset paths provided")

        # Read and merge all datasets
        dataframes = []
        for path in dataset_paths:
            try:
                df = pd.read_parquet(path).head(sample_size)
                dataframes.append(df)
            except Exception as e:
                print(f"Error loading dataset {path}: {e}")

        if not dataframes:
            raise ValueError("No valid datasets could be loaded")

        merged_df = pd.concat(dataframes, ignore_index=True)

        # Initialize data loader
        aiia_loader = AIIADataLoader(
            merged_df,
            column=column,
            batch_size=batch_size,
            pretraining=True,
            collate_fn=self.safe_collate
        )

        criterion_denoise = nn.MSELoss()
        criterion_rotate = nn.CrossEntropyLoss()
        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("-" * 20)

            # Training phase
            self.model.train()
            self.projection_head.train()
            total_train_loss = 0.0
            batch_count = 0

            for batch_data in tqdm(aiia_loader.train_loader):
                if batch_data is None:
                    continue

                self.optimizer.zero_grad()
                batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)

                if batch_loss > 0:
                    batch_loss.backward()
                    self.optimizer.step()
                    total_train_loss += batch_loss.item()
                    batch_count += 1

            avg_train_loss = total_train_loss / max(batch_count, 1)
            self.train_losses.append(avg_train_loss)
            print(f"Training Loss: {avg_train_loss:.4f}")

            # Validation phase
            self.model.eval()
            self.projection_head.eval()
            val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.model.save("AIIA-base-512")
                print("Best model saved!")

        self.save_losses('losses.csv')

    def _validate(self, val_loader, criterion_denoise, criterion_rotate):
        """Perform validation and return the average validation loss."""
        val_loss = 0.0
        val_batch_count = 0

        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue

                batch_loss = self._process_batch(
                    batch_data, criterion_denoise, criterion_rotate, training=False
                )

                if batch_loss > 0:
                    val_loss += batch_loss.item()
                    val_batch_count += 1

        avg_val_loss = val_loss / max(val_batch_count, 1)
        self.val_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.4f}")
        return avg_val_loss

    def save_losses(self, csv_file):
        """Save training and validation losses to a CSV file."""
        data = list(zip(
            range(1, len(self.train_losses) + 1),
            self.train_losses,
            self.val_losses
        ))

        with open(csv_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
            writer.writerows(data)
        print(f"Loss data has been written to {csv_file}")
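To make the collate/projection contract concrete, here is a toy run with made-up tensor sizes; `hidden_size=256` and the 32x32 images are assumptions for illustration, not shipped defaults:

```Python
import torch
from aiia.pretrain import Pretrainer, ProjectionHead

# ProjectionHead maps (B, hidden_size, H, W) features into each task's output space.
head = ProjectionHead(hidden_size=256)
feats = torch.randn(2, 256, 16, 16)
print(head(feats, task='denoise').shape)  # torch.Size([2, 3, 16, 16])
print(head(feats, task='rotate').shape)   # torch.Size([2, 4])

# safe_collate splits a mixed batch into per-task (images, targets) tuples.
batch = [
    (torch.randn(3, 32, 32), torch.randn(3, 32, 32), 'denoise'),
    (torch.randn(3, 32, 32), torch.tensor(2), 'rotate'),
]
collated = Pretrainer.safe_collate(batch)
print(collated['denoise'][0].shape)       # torch.Size([1, 3, 32, 32])
print(collated['rotate'][1])              # tensor([2])
```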

Deleted file (-226): the previous standalone pretraining script

@@ -1,226 +0,0 @@
import torch
from torch import nn
import csv
import pandas as pd
from aiia.model.config import AIIAConfig
from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader
from tqdm import tqdm

class ProjectionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)  # 4 classes: 0, 90, 180, 270 degrees

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        else:
            return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for the rotation task

def pretrain_model(data_path1, data_path2, num_epochs=3):
    # Read and merge datasets
    df1 = pd.read_parquet(data_path1).head(10000)
    df2 = pd.read_parquet(data_path2).head(10000)
    merged_df = pd.concat([df1, df2], ignore_index=True)

    # Model configuration
    config = AIIAConfig(
        model_name="AIIA-Base-512x20k",
    )

    # Initialize model and projection head
    model = AIIABase(config)
    projection_head = ProjectionHead()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    projection_head.to(device)
    def safe_collate(batch):
        denoise_batch = []
        rotate_batch = []

        for sample in batch:
            try:
                noisy_img, target, task = sample
                if task == 'denoise':
                    denoise_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
                else:  # rotate task
                    rotate_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
            except Exception as e:
                print(f"Skipping sample due to error: {e}")
                continue

        if not denoise_batch and not rotate_batch:
            return None

        batch_data = {
            'denoise': None,
            'rotate': None
        }

        if denoise_batch:
            images = torch.stack([x['image'] for x in denoise_batch])
            targets = torch.stack([x['target'] for x in denoise_batch])
            batch_data['denoise'] = (images, targets)

        if rotate_batch:
            images = torch.stack([x['image'] for x in rotate_batch])
            targets = torch.stack([x['target'] for x in rotate_batch])
            batch_data['rotate'] = (images, targets)

        return batch_data
    aiia_loader = AIIADataLoader(
        merged_df,
        column="image_bytes",
        batch_size=2,
        pretraining=True,
        collate_fn=safe_collate
    )
    train_loader = aiia_loader.train_loader
    val_loader = aiia_loader.val_loader

    criterion_denoise = nn.MSELoss()
    criterion_rotate = nn.CrossEntropyLoss()

    # Update optimizer to include projection head parameters
    optimizer = torch.optim.AdamW(
        list(model.parameters()) + list(projection_head.parameters()),
        lr=config.learning_rate
    )

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 20)

        # Training phase
        model.train()
        projection_head.train()
        total_train_loss = 0.0
        batch_count = 0

        for batch_data in tqdm(train_loader):
            if batch_data is None:
                continue

            optimizer.zero_grad()
            batch_loss = 0

            # Handle denoise task
            if batch_data['denoise'] is not None:
                noisy_imgs, targets = batch_data['denoise']
                noisy_imgs = noisy_imgs.to(device)
                targets = targets.to(device)

                # Get features from base model
                features = model(noisy_imgs)
                # Project features back to image space
                outputs = projection_head(features, task='denoise')
                loss = criterion_denoise(outputs, targets)
                batch_loss += loss

            # Handle rotate task
            if batch_data['rotate'] is not None:
                imgs, targets = batch_data['rotate']
                imgs = imgs.to(device)
                targets = targets.long().to(device)

                # Get features from base model
                features = model(imgs)
                # Project features to rotation predictions
                outputs = projection_head(features, task='rotate')
                loss = criterion_rotate(outputs, targets)
                batch_loss += loss

            if batch_loss > 0:
                batch_loss.backward()
                optimizer.step()
                total_train_loss += batch_loss.item()
                batch_count += 1

        avg_train_loss = total_train_loss / max(batch_count, 1)
        train_losses.append(avg_train_loss)
        print(f"Training Loss: {avg_train_loss:.4f}")
        # Validation phase
        model.eval()
        projection_head.eval()
        val_loss = 0.0
        val_batch_count = 0

        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue

                batch_loss = 0

                if batch_data['denoise'] is not None:
                    noisy_imgs, targets = batch_data['denoise']
                    noisy_imgs = noisy_imgs.to(device)
                    targets = targets.to(device)
                    features = model(noisy_imgs)
                    outputs = projection_head(features, task='denoise')
                    loss = criterion_denoise(outputs, targets)
                    batch_loss += loss

                if batch_data['rotate'] is not None:
                    imgs, targets = batch_data['rotate']
                    imgs = imgs.to(device)
                    targets = targets.long().to(device)
                    features = model(imgs)
                    outputs = projection_head(features, task='rotate')
                    loss = criterion_rotate(outputs, targets)
                    batch_loss += loss

                if batch_loss > 0:
                    val_loss += batch_loss.item()
                    val_batch_count += 1

        avg_val_loss = val_loss / max(val_batch_count, 1)
        val_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Save the base model (only the model is persisted, not the projection head)
            model.save("AIIA-base-512")
            print("Best model saved!")
    # Prepare the data to be written to the CSV file
    data = list(zip(range(1, len(train_losses) + 1), train_losses, val_losses))

    # Specify the CSV file name
    csv_file = 'losses.csv'

    # Write the data to the CSV file
    with open(csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Write the header
        writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
        # Write the data
        writer.writerows(data)
    print(f"Data has been written to {csv_file}")

if __name__ == "__main__":
    data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
    data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
    pretrain_model(data_path1, data_path2, num_epochs=10)
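Everything this deleted script did (the projection head, the collate function, and the train/validation loops) now lives in the `Pretrainer` class in `src/aiia/pretrain/pretrainer.py`, with the dataset paths, column name, batch size, and sample size promoted to parameters.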