working shared model (with way to few params)
This commit is contained in:
parent
6e6f4c4a21
commit
599b8c4835
|
@ -27,6 +27,44 @@ class AIIA(nn.Module):
|
||||||
model.load_state_dict(torch.load(f"{path}/model.pth"))
|
model.load_state_dict(torch.load(f"{path}/model.pth"))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
class AIIABaseShared(AIIA):
|
||||||
|
"""
|
||||||
|
Base class with parameter sharing.
|
||||||
|
All hidden layers share the same weights
|
||||||
|
"""
|
||||||
|
def __init__(self, config: AIIAConfig, **kwargs):
|
||||||
|
super().__init__(config=config, **kwargs)
|
||||||
|
self.config = copy.deepcopy(config)
|
||||||
|
|
||||||
|
# Update config with new parameters if provided
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
setattr(self.config, key, value)
|
||||||
|
|
||||||
|
# Initialize shared layers
|
||||||
|
self.conv_layer = nn.Conv2d(
|
||||||
|
self.config.num_channels,
|
||||||
|
self.config.hidden_size,
|
||||||
|
kernel_size=self.config.kernel_size,
|
||||||
|
padding=1
|
||||||
|
)
|
||||||
|
self.activation_function = getattr(nn, self.config.activation_function)()
|
||||||
|
self.max_pool = nn.MaxPool2d(kernel_size=2)
|
||||||
|
|
||||||
|
# Create a Sequential container with shared layers repeated
|
||||||
|
layers = []
|
||||||
|
for _ in range(self.config.num_hidden_layers):
|
||||||
|
layers.extend([
|
||||||
|
self.conv_layer,
|
||||||
|
self.activation_function,
|
||||||
|
self.max_pool
|
||||||
|
])
|
||||||
|
|
||||||
|
self.cnn = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.cnn(x)
|
||||||
|
|
||||||
|
|
||||||
class AIIABase(AIIA):
|
class AIIABase(AIIA):
|
||||||
def __init__(self, config: AIIAConfig, **kwargs):
|
def __init__(self, config: AIIAConfig, **kwargs):
|
||||||
super().__init__(config=config, **kwargs)
|
super().__init__(config=config, **kwargs)
|
||||||
|
@ -51,32 +89,35 @@ class AIIABase(AIIA):
|
||||||
return self.cnn(x)
|
return self.cnn(x)
|
||||||
|
|
||||||
class AIIAExpert(AIIA):
|
class AIIAExpert(AIIA):
|
||||||
def __init__(self, config: AIIAConfig, **kwargs):
|
def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
|
||||||
super().__init__(config=config, **kwargs)
|
super().__init__(config=config, **kwargs)
|
||||||
self.config = self.config
|
self.config = self.config
|
||||||
|
|
||||||
# Initialize base CNN with configuration
|
# Initialize base CNN with configuration and chosen base class
|
||||||
|
if issubclass(base_class, AIIABase):
|
||||||
self.base_cnn = AIIABase(self.config, **kwargs)
|
self.base_cnn = AIIABase(self.config, **kwargs)
|
||||||
|
elif issubclass(base_class, AIIABaseShared):
|
||||||
def forward(self, x):
|
self.base_cnn = AIIABaseShared(self.config, **kwargs)
|
||||||
return self.base_cnn(x)
|
else:
|
||||||
|
raise ValueError("Invalid base class")
|
||||||
|
|
||||||
class AIIAmoe(AIIA):
|
class AIIAmoe(AIIA):
|
||||||
def __init__(self, config: AIIAConfig, num_experts: int = 3, **kwargs):
|
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
|
||||||
super().__init__(config=config, **kwargs)
|
super().__init__(config=config, **kwargs)
|
||||||
self.config = self.config
|
self.config = self.config
|
||||||
|
|
||||||
# Update config with new parameters if provided
|
# Update config with new parameters if provided
|
||||||
self.config.num_experts = num_experts
|
self.config.num_experts = num_experts
|
||||||
|
|
||||||
# Initialize multiple experts
|
# Initialize multiple experts using chosen base class
|
||||||
self.experts = nn.ModuleList([
|
self.experts = nn.ModuleList([
|
||||||
AIIAExpert(self.config, **kwargs) for _ in range(num_experts)
|
AIIAExpert(self.config, base_class=base_class, **kwargs)
|
||||||
|
for _ in range(self.config.num_experts)
|
||||||
])
|
])
|
||||||
|
|
||||||
# Create gating network
|
# Create gating network
|
||||||
self.gate = nn.Sequential(
|
self.gate = nn.Sequential(
|
||||||
nn.Linear(self.config.hidden_size, num_experts),
|
nn.Linear(self.config.hidden_size, self.config.num_experts),
|
||||||
nn.Softmax(dim=1)
|
nn.Softmax(dim=1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -89,15 +130,20 @@ class AIIAmoe(AIIA):
|
||||||
return merged_output
|
return merged_output
|
||||||
|
|
||||||
class AIIAchunked(AIIA):
|
class AIIAchunked(AIIA):
|
||||||
def __init__(self, config: AIIAConfig, patch_size: int = 16, **kwargs):
|
def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
|
||||||
super().__init__(config=config, **kwargs)
|
super().__init__(config=config, **kwargs)
|
||||||
self.config = self.config
|
self.config = self.config
|
||||||
|
|
||||||
# Update config with new parameters if provided
|
# Update config with new parameters if provided
|
||||||
self.config.patch_size = patch_size
|
self.config.patch_size = patch_size
|
||||||
|
|
||||||
# Initialize base CNN for processing each patch
|
# Initialize base CNN for processing each patch using the specified base class
|
||||||
|
if issubclass(base_class, AIIABase):
|
||||||
self.base_cnn = AIIABase(self.config, **kwargs)
|
self.base_cnn = AIIABase(self.config, **kwargs)
|
||||||
|
elif issubclass(base_class, AIIABaseShared): # Add support for AIIABaseShared
|
||||||
|
self.base_cnn = AIIABaseShared(self.config, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid base class")
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
|
patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
|
||||||
|
@ -112,8 +158,8 @@ class AIIAchunked(AIIA):
|
||||||
combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
|
combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
|
||||||
return combined_output
|
return combined_output
|
||||||
|
|
||||||
class AIIAresursive(AIIA):
|
class AIIArecursive(AIIA):
|
||||||
def __init__(self, config: AIIAConfig, recursion_depth: int = 3, **kwargs):
|
def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
|
||||||
|
|
||||||
super().__init__(config=config, **kwargs)
|
super().__init__(config=config, **kwargs)
|
||||||
self.config = self.config
|
self.config = self.config
|
||||||
|
@ -122,7 +168,7 @@ class AIIAresursive(AIIA):
|
||||||
self.config.recursion_depth = recursion_depth
|
self.config.recursion_depth = recursion_depth
|
||||||
|
|
||||||
# Initialize chunked CNN with updated config
|
# Initialize chunked CNN with updated config
|
||||||
self.chunked_cnn = AIIAchunked(self.config, **kwargs)
|
self.chunked_cnn = AIIAchunked(self.config, base_class, **kwargs)
|
||||||
|
|
||||||
def forward(self, x, depth=0):
|
def forward(self, x, depth=0):
|
||||||
if depth == self.recursion_depth:
|
if depth == self.recursion_depth:
|
||||||
|
|
|
@ -0,0 +1,149 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import pandas as pd
|
||||||
|
from aiia.model.config import AIIAConfig
|
||||||
|
from aiia.model import AIIABase
|
||||||
|
from aiia.data.DataLoader import AIIADataLoader
|
||||||
|
|
||||||
|
def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
|
# Merge the two parquet files
|
||||||
|
df1 = pd.read_parquet(data_path1)
|
||||||
|
df2 = pd.read_parquet(data_path2)
|
||||||
|
merged_df = pd.concat([df1, df2], ignore_index=True)
|
||||||
|
|
||||||
|
# Create a new AIIAConfig instance
|
||||||
|
config = AIIAConfig(
|
||||||
|
model_name="AIIA-512x",
|
||||||
|
hidden_size=512,
|
||||||
|
num_hidden_layers=12,
|
||||||
|
kernel_size=5,
|
||||||
|
learning_rate=5e-5
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the base model
|
||||||
|
model = AIIABase(config)
|
||||||
|
|
||||||
|
# Create dataset loader with merged data
|
||||||
|
train_dataset = AIIADataLoader(
|
||||||
|
merged_df,
|
||||||
|
batch_size=32,
|
||||||
|
val_split=0.2,
|
||||||
|
seed=42,
|
||||||
|
column="file_path",
|
||||||
|
label_column=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create separate dataloaders for training and validation sets
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset.train_dataset,
|
||||||
|
batch_size=train_dataset.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=4
|
||||||
|
)
|
||||||
|
|
||||||
|
val_dataloader = DataLoader(
|
||||||
|
train_dataset.val_ataset,
|
||||||
|
batch_size=train_dataset.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=4
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize loss functions and optimizer
|
||||||
|
criterion_denoise = nn.MSELoss()
|
||||||
|
criterion_rotate = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
best_val_loss = float('inf')
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
print(f"\nEpoch {epoch+1}/{num_epochs}")
|
||||||
|
print("-" * 20)
|
||||||
|
|
||||||
|
# Training phase
|
||||||
|
model.train()
|
||||||
|
total_train_loss = 0.0
|
||||||
|
denoise_train_loss = 0.0
|
||||||
|
rotate_train_loss = 0.0
|
||||||
|
|
||||||
|
for batch in train_dataloader:
|
||||||
|
images, targets, tasks = zip(*batch)
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
|
images = [img.cuda() for img in images]
|
||||||
|
targets = [t.cuda() for t in targets]
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Process each sample individually since tasks can vary
|
||||||
|
outputs = []
|
||||||
|
total_loss = 0.0
|
||||||
|
for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
|
||||||
|
output = model(image.unsqueeze(0))
|
||||||
|
|
||||||
|
if task == 'denoise':
|
||||||
|
loss = criterion_denoise(output.squeeze(), target)
|
||||||
|
elif task == 'rotate':
|
||||||
|
loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
|
||||||
|
|
||||||
|
total_loss += loss
|
||||||
|
outputs.append(output)
|
||||||
|
|
||||||
|
avg_loss = total_loss / len(images)
|
||||||
|
avg_loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
total_train_loss += avg_loss.item()
|
||||||
|
# Separate losses for reporting (you'd need to track this based on tasks)
|
||||||
|
|
||||||
|
avg_total_train_loss = total_train_loss / len(train_dataloader)
|
||||||
|
print(f"Training Loss: {avg_total_train_loss:.4f}")
|
||||||
|
|
||||||
|
# Validation phase
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
val_losses = []
|
||||||
|
for batch in val_dataloader:
|
||||||
|
images, targets, tasks = zip(*batch)
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
|
images = [img.cuda() for img in images]
|
||||||
|
targets = [t.cuda() for t in targets]
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
total_loss = 0.0
|
||||||
|
for i, (image, target, task) in enumerate(zip(images, targets, tasks)):
|
||||||
|
output = model(image.unsqueeze(0))
|
||||||
|
|
||||||
|
if task == 'denoise':
|
||||||
|
loss = criterion_denoise(output.squeeze(), target)
|
||||||
|
elif task == 'rotate':
|
||||||
|
loss = criterion_rotate(output.view(-1, len(set(outputs))), target)
|
||||||
|
|
||||||
|
total_loss += loss
|
||||||
|
outputs.append(output)
|
||||||
|
|
||||||
|
avg_val_loss = total_loss / len(images)
|
||||||
|
val_losses.append(avg_val_loss.item())
|
||||||
|
|
||||||
|
avg_val_loss = sum(val_losses) / len(val_dataloader)
|
||||||
|
print(f"Validation Loss: {avg_val_loss:.4f}")
|
||||||
|
|
||||||
|
# Save the best model
|
||||||
|
if avg_val_loss < best_val_loss:
|
||||||
|
best_val_loss = avg_val_loss
|
||||||
|
model.save("BASEv0.1")
|
||||||
|
print("Best model saved!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
data_path1 = "/root/training_data/vision-dataset/images_dataset.parquet"
|
||||||
|
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
|
||||||
|
pretrain_model(data_path1, data_path2, num_epochs=8)
|
Loading…
Reference in New Issue