working shared model (with way too few params)

Falko Victor Habel 2025-01-24 18:04:44 +01:00
parent 6e6f4c4a21
commit 599b8c4835
2 changed files with 216 additions and 21 deletions


@@ -27,6 +27,44 @@ class AIIA(nn.Module):
        model.load_state_dict(torch.load(f"{path}/model.pth"))
        return model

class AIIABaseShared(AIIA):
    """
    Base class with parameter sharing.
    All hidden layers share the same weights.
    """
    def __init__(self, config: AIIAConfig, **kwargs):
        super().__init__(config=config, **kwargs)
        self.config = copy.deepcopy(config)

        # Update config with new parameters if provided
        for key, value in kwargs.items():
            setattr(self.config, key, value)

        # Initialize shared layers
        self.conv_layer = nn.Conv2d(
            self.config.num_channels,
            self.config.hidden_size,
            kernel_size=self.config.kernel_size,
            padding=1
        )
        self.activation_function = getattr(nn, self.config.activation_function)()
        self.max_pool = nn.MaxPool2d(kernel_size=2)

        # Create a Sequential container with the shared layers repeated
        layers = []
        for _ in range(self.config.num_hidden_layers):
            layers.extend([
                self.conv_layer,
                self.activation_function,
                self.max_pool
            ])
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.cnn(x)

class AIIABase(AIIA):
    def __init__(self, config: AIIAConfig, **kwargs):
        super().__init__(config=config, **kwargs)
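This weight sharing is also why the commit title notes the parameter count comes out far too low: the Sequential above holds the same conv module object num_hidden_layers times, and PyTorch deduplicates shared tensors when iterating parameters(), so extra depth adds no new weights. A minimal sketch to compare the two variants, assuming AIIABaseShared is exported from aiia.model alongside AIIABase and that AIIAConfig's defaults cover num_channels and activation_function:

from aiia.model.config import AIIAConfig
from aiia.model import AIIABase, AIIABaseShared  # assumes both are exported

config = AIIAConfig(hidden_size=512, num_hidden_layers=12, kernel_size=5)

# parameters() yields each registered tensor once, even when a module
# is reused, so the shared variant reports one conv layer's weights
# no matter how many times it is repeated.
for cls in (AIIABase, AIIABaseShared):
    n_params = sum(p.numel() for p in cls(config).parameters())
    print(f"{cls.__name__}: {n_params:,} parameters")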
@@ -51,32 +89,35 @@ class AIIABase(AIIA):
        return self.cnn(x)

class AIIAExpert(AIIA):
    def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)
        self.config = self.config

        # Initialize base CNN with configuration and chosen base class
        if issubclass(base_class, AIIABase):
            self.base_cnn = AIIABase(self.config, **kwargs)
        elif issubclass(base_class, AIIABaseShared):
            self.base_cnn = AIIABaseShared(self.config, **kwargs)
        else:
            raise ValueError("Invalid base class")

    def forward(self, x):
        return self.base_cnn(x)

class AIIAmoe(AIIA):
    def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)
        self.config = self.config

        # Update config with new parameters if provided
        self.config.num_experts = num_experts

        # Initialize multiple experts using chosen base class
        self.experts = nn.ModuleList([
            AIIAExpert(self.config, base_class=base_class, **kwargs)
            for _ in range(self.config.num_experts)
        ])

        # Create gating network
        self.gate = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.num_experts),
            nn.Softmax(dim=1)
        )
@@ -89,15 +130,20 @@ class AIIAmoe(AIIA):
        return merged_output

class AIIAchunked(AIIA):
    def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)
        self.config = self.config

        # Update config with new parameters if provided
        self.config.patch_size = patch_size

        # Initialize base CNN for processing each patch using the specified base class
        if issubclass(base_class, AIIABase):
            self.base_cnn = AIIABase(self.config, **kwargs)
        elif issubclass(base_class, AIIABaseShared):  # Add support for AIIABaseShared
            self.base_cnn = AIIABaseShared(self.config, **kwargs)
        else:
            raise ValueError("Invalid base class")

    def forward(self, x):
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
@@ -112,8 +158,8 @@ class AIIAchunked(AIIA):
        combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
        return combined_output

class AIIArecursive(AIIA):
    def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)
        self.config = self.config

@@ -122,7 +168,7 @@ class AIIAresursive(AIIA):
        self.config.recursion_depth = recursion_depth

        # Initialize chunked CNN with updated config
        self.chunked_cnn = AIIAchunked(self.config, base_class, **kwargs)

    def forward(self, x, depth=0):
        if depth == self.recursion_depth:
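For reference, a minimal usage sketch of the new base_class argument threading through the model family (the config values are illustrative, and the imports assume these classes are exported from aiia.model):

from aiia.model.config import AIIAConfig
from aiia.model import AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked

config = AIIAConfig(hidden_size=512, num_hidden_layers=12, kernel_size=5)

# Mixture of experts whose experts use the weight-shared base
moe = AIIAmoe(config, num_experts=3, base_class=AIIABaseShared)

# Patch-wise model on the standard base (the default)
chunked = AIIAchunked(config, patch_size=16, base_class=AIIABase)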

src/pretrain.py (new file, 149 lines)

@@ -0,0 +1,149 @@
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import random
import pandas as pd
from aiia.model.config import AIIAConfig
from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader
def pretrain_model(data_path1, data_path2, num_epochs=3):
    # Merge the two parquet files
    df1 = pd.read_parquet(data_path1)
    df2 = pd.read_parquet(data_path2)
    merged_df = pd.concat([df1, df2], ignore_index=True)

    # Create a new AIIAConfig instance
    config = AIIAConfig(
        model_name="AIIA-512x",
        hidden_size=512,
        num_hidden_layers=12,
        kernel_size=5,
        learning_rate=5e-5
    )

    # Initialize the base model
    model = AIIABase(config)

    # Create dataset loader with merged data
    train_dataset = AIIADataLoader(
        merged_df,
        batch_size=32,
        val_split=0.2,
        seed=42,
        column="file_path",
        label_column=None
    )

    # Create separate dataloaders for the training and validation sets
    train_dataloader = DataLoader(
        train_dataset.train_dataset,
        batch_size=train_dataset.batch_size,
        shuffle=True,
        num_workers=4
    )
    val_dataloader = DataLoader(
        train_dataset.val_dataset,
        batch_size=train_dataset.batch_size,
        shuffle=False,
        num_workers=4
    )
    # Initialize loss functions and optimizer
    criterion_denoise = nn.MSELoss()
    criterion_rotate = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    # Assumption: the rotation pretext task uses four classes
    # (0/90/180/270 degrees); adjust to match the dataset.
    num_rotation_classes = 4

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 20)

        # Training phase
        model.train()
        total_train_loss = 0.0
        denoise_train_loss = 0.0
        rotate_train_loss = 0.0

        for batch in train_dataloader:
            images, targets, tasks = zip(*batch)
            if device == "cuda":
                images = [img.cuda() for img in images]
                targets = [t.cuda() for t in targets]

            optimizer.zero_grad()

            # Process each sample individually since tasks can vary within a batch
            total_loss = 0.0
            for image, target, task in zip(images, targets, tasks):
                output = model(image.unsqueeze(0))
                if task == 'denoise':
                    loss = criterion_denoise(output.squeeze(), target)
                elif task == 'rotate':
                    loss = criterion_rotate(output.view(-1, num_rotation_classes), target)
                total_loss += loss

            avg_loss = total_loss / len(images)
            avg_loss.backward()
            optimizer.step()
            total_train_loss += avg_loss.item()
            # Separate losses for reporting (you'd need to track these based on tasks)

        avg_total_train_loss = total_train_loss / len(train_dataloader)
        print(f"Training Loss: {avg_total_train_loss:.4f}")

        # Validation phase
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_dataloader:
                images, targets, tasks = zip(*batch)
                if device == "cuda":
                    images = [img.cuda() for img in images]
                    targets = [t.cuda() for t in targets]

                total_loss = 0.0
                for image, target, task in zip(images, targets, tasks):
                    output = model(image.unsqueeze(0))
                    if task == 'denoise':
                        loss = criterion_denoise(output.squeeze(), target)
                    elif task == 'rotate':
                        loss = criterion_rotate(output.view(-1, num_rotation_classes), target)
                    total_loss += loss

                avg_val_loss = total_loss / len(images)
                val_losses.append(avg_val_loss.item())

        avg_val_loss = sum(val_losses) / len(val_dataloader)
        print(f"Validation Loss: {avg_val_loss:.4f}")

        # Save the best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save("BASEv0.1")
            print("Best model saved!")
if __name__ == "__main__":
    data_path1 = "/root/training_data/vision-dataset/images_dataset.parquet"
    data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
    pretrain_model(data_path1, data_path2, num_epochs=8)
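One assumption in the training loop above is worth spelling out: zip(*batch) only works if the dataloader yields each batch as a plain list of (image, target, task) triples rather than PyTorch's default stacked tensors. If AIIADataLoader does not already arrange that, a hypothetical identity collate function would produce the expected shape:

def list_collate(samples):
    # Hypothetical helper: return the list of (image, target, task)
    # triples unchanged so the training loop can unpack it with zip(*batch).
    return samples

train_dataloader = DataLoader(
    train_dataset.train_dataset,
    batch_size=train_dataset.batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=list_collate
)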