working shared model (with way too few params)
parent 6e6f4c4a21
commit 599b8c4835
@@ -27,6 +27,44 @@ class AIIA(nn.Module):
         model.load_state_dict(torch.load(f"{path}/model.pth"))
         return model
 
+class AIIABaseShared(AIIA):
+    """
+    Base class with parameter sharing.
+    All hidden layers share the same weights.
+    """
+    def __init__(self, config: AIIAConfig, **kwargs):
+        super().__init__(config=config, **kwargs)
+        self.config = copy.deepcopy(config)
+
+        # Update config with new parameters if provided
+        for key, value in kwargs.items():
+            setattr(self.config, key, value)
+
+        # Initialize shared layers
+        self.conv_layer = nn.Conv2d(
+            self.config.num_channels,
+            self.config.hidden_size,
+            kernel_size=self.config.kernel_size,
+            padding=1
+        )
+        self.activation_function = getattr(nn, self.config.activation_function)()
+        self.max_pool = nn.MaxPool2d(kernel_size=2)
+
+        # Create a Sequential container with the shared layers repeated
+        layers = []
+        for _ in range(self.config.num_hidden_layers):
+            layers.extend([
+                self.conv_layer,
+                self.activation_function,
+                self.max_pool
+            ])
+
+        self.cnn = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.cnn(x)
+
+
 class AIIABase(AIIA):
     def __init__(self, config: AIIAConfig, **kwargs):
         super().__init__(config=config, **kwargs)
@@ -51,53 +89,61 @@ class AIIABase(AIIA):
         return self.cnn(x)
 
 class AIIAExpert(AIIA):
-    def __init__(self, config: AIIAConfig, **kwargs):
-        super().__init__(config=config, **kwargs)
+    def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
+        super().__init__(config=config, **kwargs)
         self.config = self.config
 
-        # Initialize base CNN with configuration
-        self.base_cnn = AIIABase(self.config, **kwargs)
-
-    def forward(self, x):
-        return self.base_cnn(x)
+        # Initialize base CNN with configuration and chosen base class
+        if issubclass(base_class, AIIABase):
+            self.base_cnn = AIIABase(self.config, **kwargs)
+        elif issubclass(base_class, AIIABaseShared):
+            self.base_cnn = AIIABaseShared(self.config, **kwargs)
+        else:
+            raise ValueError("Invalid base class")
 
 class AIIAmoe(AIIA):
-    def __init__(self, config: AIIAConfig, num_experts: int = 3, **kwargs):
-        super().__init__(config=config, **kwargs)
+    def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
+        super().__init__(config=config, **kwargs)
         self.config = self.config
 
         # Update config with new parameters if provided
        self.config.num_experts = num_experts
 
-        # Initialize multiple experts
+        # Initialize multiple experts using chosen base class
         self.experts = nn.ModuleList([
-            AIIAExpert(self.config, **kwargs) for _ in range(num_experts)
-        ])
+            AIIAExpert(self.config, base_class=base_class, **kwargs)
+            for _ in range(self.config.num_experts)
+        ])
 
         # Create gating network
         self.gate = nn.Sequential(
-            nn.Linear(self.config.hidden_size, num_experts),
+            nn.Linear(self.config.hidden_size, self.config.num_experts),
             nn.Softmax(dim=1)
         )
 
     def forward(self, x):
         expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
         gate_weights = self.gate(torch.mean(expert_outputs, (2, 3)))
         merged_output = torch.sum(
             expert_outputs * gate_weights.unsqueeze(2).unsqueeze(3), dim=1
         )
         return merged_output
 
 class AIIAchunked(AIIA):
-    def __init__(self, config: AIIAConfig, patch_size: int = 16, **kwargs):
-        super().__init__(config=config, **kwargs)
+    def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
+        super().__init__(config=config, **kwargs)
         self.config = self.config
 
         # Update config with new parameters if provided
         self.config.patch_size = patch_size
 
-        # Initialize base CNN for processing each patch
-        self.base_cnn = AIIABase(self.config, **kwargs)
+        # Initialize base CNN for processing each patch using the specified base class
+        if issubclass(base_class, AIIABase):
+            self.base_cnn = AIIABase(self.config, **kwargs)
+        elif issubclass(base_class, AIIABaseShared):  # Add support for AIIABaseShared
+            self.base_cnn = AIIABaseShared(self.config, **kwargs)
+        else:
+            raise ValueError("Invalid base class")
 
     def forward(self, x):
         patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
@@ -112,8 +158,8 @@ class AIIAchunked(AIIA):
         combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
         return combined_output
 
-class AIIAresursive(AIIA):
-    def __init__(self, config: AIIAConfig, recursion_depth: int = 3, **kwargs):
+class AIIArecursive(AIIA):
+    def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
         super().__init__(config=config, **kwargs)
         self.config = self.config
@@ -122,7 +168,7 @@ class AIIAresursive(AIIA):
         self.config.recursion_depth = recursion_depth
 
         # Initialize chunked CNN with updated config
-        self.chunked_cnn = AIIAchunked(self.config, **kwargs)
+        self.chunked_cnn = AIIAchunked(self.config, base_class, **kwargs)
 
     def forward(self, x, depth=0):
         if depth == self.recursion_depth:
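The model-file changes above give every composite class (AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive) a base_class argument that selects between AIIABase and the new weight-sharing AIIABaseShared. A minimal construction-only sketch of the new parameter (it assumes AIIAmoe and AIIABaseShared are exported from aiia.model and that AIIAConfig supplies defaults for any fields not set here):

    from aiia.model.config import AIIAConfig
    from aiia.model import AIIAmoe, AIIABaseShared  # assumed exports

    # Hypothetical config; hidden_size/num_hidden_layers/kernel_size mirror the
    # pretraining script below, the remaining fields use AIIAConfig defaults.
    config = AIIAConfig(hidden_size=512, num_hidden_layers=12, kernel_size=5)

    # All three experts now wrap the weight-sharing base instead of AIIABase.
    moe = AIIAmoe(config, num_experts=3, base_class=AIIABaseShared)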
@@ -0,0 +1,149 @@
+import torch
+from torch import nn
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+import os
+import random
+import pandas as pd
+from aiia.model.config import AIIAConfig
+from aiia.model import AIIABase
+from aiia.data.DataLoader import AIIADataLoader
+
+def pretrain_model(data_path1, data_path2, num_epochs=3):
+    # Merge the two parquet files
+    df1 = pd.read_parquet(data_path1)
+    df2 = pd.read_parquet(data_path2)
+    merged_df = pd.concat([df1, df2], ignore_index=True)
+
+    # Create a new AIIAConfig instance
+    config = AIIAConfig(
+        model_name="AIIA-512x",
+        hidden_size=512,
+        num_hidden_layers=12,
+        kernel_size=5,
+        learning_rate=5e-5
+    )
+
+    # Initialize the base model
+    model = AIIABase(config)
+
+    # Create dataset loader with merged data
+    train_dataset = AIIADataLoader(
+        merged_df,
+        batch_size=32,
+        val_split=0.2,
+        seed=42,
+        column="file_path",
+        label_column=None
+    )
+
+    # Create separate dataloaders for the training and validation sets
+    train_dataloader = DataLoader(
+        train_dataset.train_dataset,
+        batch_size=train_dataset.batch_size,
+        shuffle=True,
+        num_workers=4
+    )
+
+    val_dataloader = DataLoader(
+        train_dataset.val_dataset,
+        batch_size=train_dataset.batch_size,
+        shuffle=False,
+        num_workers=4
+    )
+
+    # Initialize loss functions and optimizer
+    criterion_denoise = nn.MSELoss()
+    criterion_rotate = nn.CrossEntropyLoss()
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.to(device)
+
+    best_val_loss = float('inf')
+
+    for epoch in range(num_epochs):
+        print(f"\nEpoch {epoch+1}/{num_epochs}")
+        print("-" * 20)
+
+        # Training phase
+        model.train()
+        total_train_loss = 0.0
+        denoise_train_loss = 0.0
+        rotate_train_loss = 0.0
+
+        for batch in train_dataloader:
+            images, targets, tasks = zip(*batch)
+
+            if device == "cuda":
+                images = [img.cuda() for img in images]
+                targets = [t.cuda() for t in targets]
+
+            optimizer.zero_grad()
+
+            # Process each sample individually since tasks can vary
+            outputs = []
+            total_loss = 0.0
+            for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
+                output = model(image.unsqueeze(0))
+
+                if task == 'denoise':
+                    loss = criterion_denoise(output.squeeze(), target)
+                elif task == 'rotate':
+                    loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
+
+                total_loss += loss
+                outputs.append(output)
+
+            avg_loss = total_loss / len(images)
+            avg_loss.backward()
+            optimizer.step()
+
+            total_train_loss += avg_loss.item()
+            # Separate losses for reporting (you'd need to track this based on tasks)
+
+        avg_total_train_loss = total_train_loss / len(train_dataloader)
+        print(f"Training Loss: {avg_total_train_loss:.4f}")
+
+        # Validation phase
+        model.eval()
+        with torch.no_grad():
+            val_losses = []
+            for batch in val_dataloader:
+                images, targets, tasks = zip(*batch)
+
+                if device == "cuda":
+                    images = [img.cuda() for img in images]
+                    targets = [t.cuda() for t in targets]
+
+                outputs = []
+                total_loss = 0.0
+                for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
+                    output = model(image.unsqueeze(0))
+
+                    if task == 'denoise':
+                        loss = criterion_denoise(output.squeeze(), target)
+                    elif task == 'rotate':
+                        loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
+
+                    total_loss += loss
+                    outputs.append(output)
+
+                avg_val_loss = total_loss / len(images)
+                val_losses.append(avg_val_loss.item())
+
+            avg_val_loss = sum(val_losses) / len(val_dataloader)
+            print(f"Validation Loss: {avg_val_loss:.4f}")
+
+            # Save the best model
+            if avg_val_loss < best_val_loss:
+                best_val_loss = avg_val_loss
+                model.save("BASEv0.1")
+                print("Best model saved!")
+
+if __name__ == "__main__":
+    data_path1 = "/root/training_data/vision-dataset/images_dataset.parquet"
+    data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
+    pretrain_model(data_path1, data_path2, num_epochs=8)
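On the "(with way too few params)" caveat in the commit message: AIIABaseShared appends the same Conv2d instance num_hidden_layers times, and nn.Module.parameters() yields each registered tensor only once, so the shared model's parameter count stays at one layer's worth no matter the depth. A sketch for checking this, assuming AIIABase constructs independent layers per depth step (its builder is not shown in this diff) and reusing the hypothetical config from above:

    from aiia.model.config import AIIAConfig
    from aiia.model import AIIABase, AIIABaseShared  # assumed exports

    def count_params(model):
        # parameters() deduplicates shared tensors, so a layer reused
        # N times in a Sequential is still counted once.
        return sum(p.numel() for p in model.parameters())

    config = AIIAConfig(hidden_size=512, num_hidden_layers=12, kernel_size=5)
    print(count_params(AIIABase(config)))        # expected to grow with num_hidden_layers
    print(count_params(AIIABaseShared(config)))  # roughly constant in num_hidden_layers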