working shared model (with way to few params)

This commit is contained in:
Falko Victor Habel 2025-01-24 18:04:44 +01:00
parent 6e6f4c4a21
commit 599b8c4835
2 changed files with 216 additions and 21 deletions

View File

@ -27,6 +27,44 @@ class AIIA(nn.Module):
model.load_state_dict(torch.load(f"{path}/model.pth"))
return model
class AIIABaseShared(AIIA):
"""
Base class with parameter sharing.
All hidden layers share the same weights
"""
def __init__(self, config: AIIAConfig, **kwargs):
super().__init__(config=config, **kwargs)
self.config = copy.deepcopy(config)
# Update config with new parameters if provided
for key, value in kwargs.items():
setattr(self.config, key, value)
# Initialize shared layers
self.conv_layer = nn.Conv2d(
self.config.num_channels,
self.config.hidden_size,
kernel_size=self.config.kernel_size,
padding=1
)
self.activation_function = getattr(nn, self.config.activation_function)()
self.max_pool = nn.MaxPool2d(kernel_size=2)
# Create a Sequential container with shared layers repeated
layers = []
for _ in range(self.config.num_hidden_layers):
layers.extend([
self.conv_layer,
self.activation_function,
self.max_pool
])
self.cnn = nn.Sequential(*layers)
def forward(self, x):
return self.cnn(x)
class AIIABase(AIIA):
def __init__(self, config: AIIAConfig, **kwargs):
super().__init__(config=config, **kwargs)
@ -51,53 +89,61 @@ class AIIABase(AIIA):
return self.cnn(x)
class AIIAExpert(AIIA):
def __init__(self, config: AIIAConfig, **kwargs):
super().__init__(config=config, **kwargs)
def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
super().__init__(config=config, **kwargs)
self.config = self.config
# Initialize base CNN with configuration
self.base_cnn = AIIABase(self.config, **kwargs)
def forward(self, x):
return self.base_cnn(x)
# Initialize base CNN with configuration and chosen base class
if issubclass(base_class, AIIABase):
self.base_cnn = AIIABase(self.config, **kwargs)
elif issubclass(base_class, AIIABaseShared):
self.base_cnn = AIIABaseShared(self.config, **kwargs)
else:
raise ValueError("Invalid base class")
class AIIAmoe(AIIA):
def __init__(self, config: AIIAConfig, num_experts: int = 3, **kwargs):
super().__init__(config=config, **kwargs)
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
super().__init__(config=config, **kwargs)
self.config = self.config
# Update config with new parameters if provided
self.config.num_experts = num_experts
# Initialize multiple experts
# Initialize multiple experts using chosen base class
self.experts = nn.ModuleList([
AIIAExpert(self.config, **kwargs) for _ in range(num_experts)
])
AIIAExpert(self.config, base_class=base_class, **kwargs)
for _ in range(self.config.num_experts)
])
# Create gating network
self.gate = nn.Sequential(
nn.Linear(self.config.hidden_size, num_experts),
nn.Linear(self.config.hidden_size, self.config.num_experts),
nn.Softmax(dim=1)
)
)
def forward(self, x):
expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
gate_weights = self.gate(torch.mean(expert_outputs, (2, 3)))
merged_output = torch.sum(
expert_outputs * gate_weights.unsqueeze(2).unsqueeze(3), dim=1
)
)
return merged_output
class AIIAchunked(AIIA):
def __init__(self, config: AIIAConfig, patch_size: int = 16, **kwargs):
def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
super().__init__(config=config, **kwargs)
self.config = self.config
# Update config with new parameters if provided
self.config.patch_size = patch_size
# Initialize base CNN for processing each patch
self.base_cnn = AIIABase(self.config, **kwargs)
# Initialize base CNN for processing each patch using the specified base class
if issubclass(base_class, AIIABase):
self.base_cnn = AIIABase(self.config, **kwargs)
elif issubclass(base_class, AIIABaseShared): # Add support for AIIABaseShared
self.base_cnn = AIIABaseShared(self.config, **kwargs)
else:
raise ValueError("Invalid base class")
def forward(self, x):
patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
@ -112,8 +158,8 @@ class AIIAchunked(AIIA):
combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
return combined_output
class AIIAresursive(AIIA):
def __init__(self, config: AIIAConfig, recursion_depth: int = 3, **kwargs):
class AIIArecursive(AIIA):
def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
super().__init__(config=config, **kwargs)
self.config = self.config
@ -122,7 +168,7 @@ class AIIAresursive(AIIA):
self.config.recursion_depth = recursion_depth
# Initialize chunked CNN with updated config
self.chunked_cnn = AIIAchunked(self.config, **kwargs)
self.chunked_cnn = AIIAchunked(self.config, base_class, **kwargs)
def forward(self, x, depth=0):
if depth == self.recursion_depth:

149
src/pretrain.py Normal file
View File

@ -0,0 +1,149 @@
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import random
import pandas as pd
from aiia.model.config import AIIAConfig
from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader
def pretrain_model(data_path1, data_path2, num_epochs=3):
# Merge the two parquet files
df1 = pd.read_parquet(data_path1)
df2 = pd.read_parquet(data_path2)
merged_df = pd.concat([df1, df2], ignore_index=True)
# Create a new AIIAConfig instance
config = AIIAConfig(
model_name="AIIA-512x",
hidden_size=512,
num_hidden_layers=12,
kernel_size=5,
learning_rate=5e-5
)
# Initialize the base model
model = AIIABase(config)
# Create dataset loader with merged data
train_dataset = AIIADataLoader(
merged_df,
batch_size=32,
val_split=0.2,
seed=42,
column="file_path",
label_column=None
)
# Create separate dataloaders for training and validation sets
train_dataloader = DataLoader(
train_dataset.train_dataset,
batch_size=train_dataset.batch_size,
shuffle=True,
num_workers=4
)
val_dataloader = DataLoader(
train_dataset.val_ataset,
batch_size=train_dataset.batch_size,
shuffle=False,
num_workers=4
)
# Initialize loss functions and optimizer
criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
best_val_loss = float('inf')
for epoch in range(num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
print("-" * 20)
# Training phase
model.train()
total_train_loss = 0.0
denoise_train_loss = 0.0
rotate_train_loss = 0.0
for batch in train_dataloader:
images, targets, tasks = zip(*batch)
if device == "cuda":
images = [img.cuda() for img in images]
targets = [t.cuda() for t in targets]
optimizer.zero_grad()
# Process each sample individually since tasks can vary
outputs = []
total_loss = 0.0
for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
output = model(image.unsqueeze(0))
if task == 'denoise':
loss = criterion_denoise(output.squeeze(), target)
elif task == 'rotate':
loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
total_loss += loss
outputs.append(output)
avg_loss = total_loss / len(images)
avg_loss.backward()
optimizer.step()
total_train_loss += avg_loss.item()
# Separate losses for reporting (you'd need to track this based on tasks)
avg_total_train_loss = total_train_loss / len(train_dataloader)
print(f"Training Loss: {avg_total_train_loss:.4f}")
# Validation phase
model.eval()
with torch.no_grad():
val_losses = []
for batch in val_dataloader:
images, targets, tasks = zip(*batch)
if device == "cuda":
images = [img.cuda() for img in images]
targets = [t.cuda() for t in targets]
outputs = []
total_loss = 0.0
for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
output = model(image.unsqueeze(0))
if task == 'denoise':
loss = criterion_denoise(output.squeeze(), target)
elif task == 'rotate':
loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
total_loss += loss
outputs.append(output)
avg_val_loss = total_loss / len(images)
val_losses.append(avg_val_loss.item())
avg_val_loss = sum(val_losses) / len(val_dataloader)
print(f"Validation Loss: {avg_val_loss:.4f}")
# Save the best model
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
model.save("BASEv0.1")
print("Best model saved!")
if __name__ == "__main__":
data_path1 = "/root/training_data/vision-dataset/images_dataset.parquet"
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
pretrain_model(data_path1, data_path2, num_epochs=8)