Updated saving and first implementation of new additional parameter handling
This commit is contained in:
parent
26b701fd77
commit
ab58d352c4
|
@ -1,7 +1,7 @@
|
|||
from config import AIIAConfig
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
import os
|
||||
|
||||
class AIIA(nn.Module):
|
||||
def __init__(self, config: AIIAConfig, **kwargs):
|
||||
|
@ -12,56 +12,70 @@ class AIIA(nn.Module):
|
|||
for key, value in kwargs.items():
|
||||
setattr(self.config, key, value)
|
||||
|
||||
def save(self, path: str):
    """Save the model weights and its configuration into one directory.

    The state dict is written to ``<path>/model.pth`` and the configuration
    is written by ``self.config.save(path)``.

    Args:
        path: Target directory. It is created, including missing parents,
            when it does not exist yet.
    """
    # exist_ok=True already makes this a no-op for an existing directory, so
    # a separate (race-prone) os.path.exists check is unnecessary.
    os.makedirs(path, exist_ok=True)
    torch.save(self.state_dict(), os.path.join(path, "model.pth"))
    self.config.save(path)
|
||||
|
||||
@classmethod
def load(cls, path):
    """Rebuild a model from a directory containing config and weights.

    Args:
        path: Directory holding the configuration read by
            ``AIIAConfig.load(path)`` and the weights file ``model.pth``.

    Returns:
        An instance of ``cls`` constructed from the loaded configuration
        with the stored state dict loaded into it.
    """
    config = AIIAConfig.load(path)
    model = cls(config)
    # map_location="cpu" lets a checkpoint written on a GPU machine be read
    # on a CPU-only host; load_state_dict then copies the values into the
    # parameters the freshly built model already owns.
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # from trusted sources.
    state_dict = torch.load(os.path.join(path, "model.pth"), map_location="cpu")
    model.load_state_dict(state_dict)
    return model
|
||||
|
||||
class AIIABase(AIIA):
    """Plain convolutional backbone assembled from the configuration.

    The network is ``num_hidden_layers`` repetitions of
    ``Conv2d -> activation -> MaxPool2d(2)`` stored as ``self.cnn``.
    """

    def __init__(self, config: AIIAConfig, **kwargs):
        """Build the convolutional stack.

        Args:
            config: Supplies ``num_channels``, ``hidden_size``,
                ``kernel_size``, ``activation_function`` and
                ``num_hidden_layers``.
            **kwargs: Extra settings forwarded to ``AIIA.__init__``, whose
                visible tail stores each one as an attribute on the config.
        """
        # Forward kwargs instead of silently dropping them, so overrides are
        # applied to the config before any layer is created. This relies on
        # the parent initialiser exposing the config as ``self.config``.
        super().__init__(config, **kwargs)

        cfg = self.config
        # Resolve the activation class once; it is the same for every layer.
        activation_cls = getattr(nn, cfg.activation_function)

        layers = []
        in_channels = cfg.num_channels
        for _ in range(cfg.num_hidden_layers):
            layers.extend([
                # NOTE(review): padding is fixed at 1 although kernel_size is
                # configurable; for kernel sizes other than 3 the convolution
                # changes the spatial size - confirm this is intended.
                nn.Conv2d(in_channels, cfg.hidden_size,
                          kernel_size=cfg.kernel_size, padding=1),
                activation_cls(),
                nn.MaxPool2d(kernel_size=2),
            ])
            # Every layer after the first consumes hidden_size channels.
            in_channels = cfg.hidden_size

        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the result."""
        return self.cnn(x)
|
||||
|
||||
class AIIAExpert(AIIA):
    """Single expert: a thin wrapper around one ``AIIABase`` network."""

    def __init__(self, config: AIIAConfig, **kwargs):
        """Create the expert and its underlying base CNN.

        Args:
            config: Model configuration handed to the parent initialiser and
                to the wrapped ``AIIABase``.
            **kwargs: Extra settings forwarded both to ``AIIA.__init__`` and
                to ``AIIABase`` instead of being dropped on the way.
        """
        # No attribute assignment before the parent initialiser runs; the
        # parent is expected to expose the config as ``self.config``.
        super().__init__(config, **kwargs)

        # Base CNN that does the actual work for this expert.
        self.base_cnn = AIIABase(self.config, **kwargs)

    def forward(self, x):
        """Delegate the forward pass to the wrapped base CNN."""
        return self.base_cnn(x)
|
||||
|
||||
class AIIAmoe(AIIA):
|
||||
def __init__(self, config: AIIAConfig, num_experts: int = 3, **kwargs):
    """Create the experts and the gating network.

    Args:
        config: Model configuration; ``hidden_size`` determines the input
            width of the gate.
        num_experts: Number of ``AIIAExpert`` instances to create. The value
            is also recorded on the config as ``num_experts``.
        **kwargs: Extra settings forwarded to ``AIIA.__init__`` and to every
            expert instead of being dropped.
    """
    # Run the parent initialiser first; it is expected to expose the config
    # as ``self.config``, so no assignment is needed beforehand.
    super().__init__(config, **kwargs)

    # Record the expert count on the config so it is saved with the model.
    self.config.num_experts = num_experts

    # Initialize multiple experts
    self.experts = nn.ModuleList(
        [AIIAExpert(self.config, **kwargs) for _ in range(num_experts)]
    )

    # Gating network: one softmax-normalised weight per expert.
    self.gate = nn.Sequential(
        nn.Linear(self.config.hidden_size, num_experts),
        nn.Softmax(dim=1),
    )
|
||||
|
||||
|
@ -74,15 +88,15 @@ class AIIAmoe(AIIA):
|
|||
return merged_output
|
||||
|
||||
class AIIAchunked(AIIA):
|
||||
def __init__(self, config: AIIAConfig, patch_size: int = 16, **kwargs):
    """Create the patch-wise model.

    Args:
        config: Model configuration handed to the parent initialiser and to
            the base CNN.
        patch_size: Edge length used when the input is unfolded into patches.
            It is stored on the instance and recorded on the config.
        **kwargs: Extra settings forwarded to ``AIIA.__init__`` and to
            ``AIIABase`` instead of being dropped.
    """
    # Run the parent initialiser first; it is expected to expose the config
    # as ``self.config``, so no assignment is needed beforehand.
    super().__init__(config, **kwargs)

    # forward() reads self.patch_size, so keep it on the instance as well as
    # on the config (where it is saved with the model).
    self.patch_size = patch_size
    self.config.patch_size = patch_size

    # Initialize base CNN for processing each patch
    self.base_cnn = AIIABase(self.config, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
|
||||
|
@ -98,16 +112,16 @@ class AIIAchunked(AIIA):
|
|||
return combined_output
|
||||
|
||||
class AIIAresursive(AIIA):
|
||||
def __init__(self, config: AIIAConfig, recursion_depth: int = 3, **kwargs):
    """Create the recursive model on top of a chunked CNN.

    Args:
        config: Model configuration handed to the parent initialiser and to
            the chunked CNN.
        recursion_depth: Depth at which ``forward`` stops recursing and calls
            the chunked CNN directly.
        **kwargs: Extra settings forwarded to ``AIIA.__init__`` and to
            ``AIIAchunked``.
    """
    # Pass recursion_depth as a kwarg so the parent initialiser records it on
    # the config together with any other overrides.
    super().__init__(config, recursion_depth=recursion_depth, **kwargs)

    # Fall back to the argument itself rather than an unrelated literal, so
    # the instance can never disagree with what the caller asked for.
    self.recursion_depth = getattr(self.config, 'recursion_depth', recursion_depth)

    # Initialize chunked CNN with updated config
    self.chunked_cnn = AIIAchunked(self.config, **kwargs)
|
||||
|
||||
def forward(self, x, depth=0):
|
||||
if depth == self.recursion_depth:
|
||||
return self.chunked_cnn(x)
|
||||
|
@ -123,3 +137,12 @@ class AIIAresursive(AIIA):
|
|||
|
||||
combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
|
||||
return combined_output
|
||||
|
||||
|
||||
|
||||
# Ad-hoc smoke test executed at module level: build a default configuration
# and instantiate several model variants from it.
config = AIIAConfig()

# Each assignment rebinds ``model``; only the last instance is kept.
model = AIIABase(config)
model = AIIAmoe(config=config, num_experts=5)
model = AIIAresursive(config=config)
# Saves the AIIAresursive instance.
# NOTE(review): the directory name "moe" does not match the model type that
# is actually saved - confirm intent.
model.save("moe")
|
|
@ -1,15 +1,17 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
class AIIAConfig:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "AIIA",
|
||||
kernel_size: int = 3,
|
||||
kernel_size: int = 5,
|
||||
activation_function: str = "GELU",
|
||||
hidden_size: int = 128,
|
||||
num_hidden_layers: int = 2,
|
||||
hidden_size: int = 256,
|
||||
num_hidden_layers: int = 12,
|
||||
num_channels: int = 3,
|
||||
learning_rate: float = 5e-5,
|
||||
**kwargs
|
||||
|
@ -39,11 +41,13 @@ class AIIAConfig:
|
|||
self._activation_function = value
|
||||
|
||||
def save(self, file_path):
    """Write the configuration as JSON to ``<file_path>/config.json``.

    Args:
        file_path: Target directory. It is created, including missing
            parents, when it does not exist yet.
    """
    # exist_ok=True already covers the "directory exists" case, so no
    # separate os.path.exists check is needed.
    os.makedirs(file_path, exist_ok=True)
    # vars(self) serialises every instance attribute as stored on the object.
    with open(os.path.join(file_path, "config.json"), 'w', encoding="utf-8") as f:
        json.dump(vars(self), f, indent=4)
|
||||
|
||||
@classmethod
def load(cls, file_path):
    """Create a configuration from ``<file_path>/config.json``.

    Args:
        file_path: Directory that contains ``config.json``.

    Returns:
        ``cls`` called with the JSON object's entries as keyword arguments.
    """
    with open(os.path.join(file_path, "config.json"), 'r', encoding="utf-8") as f:
        config_dict = json.load(f)
    return cls(**config_dict)
|
Loading…
Reference in New Issue