Merge pull request 'improved cnn' (#2) from improved_shared_cnn into develop
Reviewed-on: Fabel/AIIA#2
This commit is contained in:
commit
d7d2095d4a
|
@ -26,68 +26,81 @@ class AIIA(nn.Module):
|
|||
model = cls(config)
|
||||
model.load_state_dict(torch.load(f"{path}/model.pth"))
|
||||
return model
|
||||
|
||||
|
||||
|
||||
class AIIABaseShared(AIIA):
|
||||
def __init__(self, config: AIIAConfig, num_shared_layers=1, **kwargs):
|
||||
def __init__(self, config: AIIAConfig, **kwargs):
|
||||
"""
|
||||
Initialize the AIIABaseShared model.
|
||||
|
||||
Args:
|
||||
config (AIIAConfig): Configuration object containing model parameters.
|
||||
**kwargs: Additional keyword arguments to override configuration settings.
|
||||
"""
|
||||
super().__init__(config=config, **kwargs)
|
||||
self.config = copy.deepcopy(config)
|
||||
self.config.num_shared_layers = num_shared_layers
|
||||
# Update config with new parameters if provided
|
||||
|
||||
# Update configuration with new parameters if provided
|
||||
self.config = copy.deepcopy(config)
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(self.config, key, value)
|
||||
|
||||
# Initialize the network components
|
||||
self._initialize_network()
|
||||
self._initialize_activation_andPooling()
|
||||
|
||||
# Shared layers (early stages) use the same kernel
|
||||
self.shared_layers = nn.ModuleList()
|
||||
for _ in range(self.config.num_shared_layers):
|
||||
layer = nn.Conv2d(
|
||||
self.config.num_channels,
|
||||
self.config.hidden_size,
|
||||
kernel_size=self.config.kernel_size,
|
||||
padding=1
|
||||
)
|
||||
# Initialize with shared weights if it's the first layer
|
||||
if len(self.shared_layers) == 0:
|
||||
self.shared_weights = layer.weight
|
||||
self.shared_biases = nn.ParameterList([
|
||||
nn.Parameter(torch.zeros(self.config.hidden_size))
|
||||
for _ in range(self.config.num_shared_layers)
|
||||
])
|
||||
else:
|
||||
layer.weight = self.shared_weights
|
||||
# Assign separate biases
|
||||
layer.bias = self.shared_biases[len(self.shared_layers)]
|
||||
self.shared_layers.append(layer)
|
||||
|
||||
# Unique layers (later stages) have their own weights and biases
|
||||
def _initialize_network(self):
|
||||
"""Initialize the shared and unique layers of the network."""
|
||||
# Create a single shared convolutional layer
|
||||
self.shared_layer = nn.Conv2d(
|
||||
in_channels=self.config.num_channels,
|
||||
out_channels=self.config.hidden_size,
|
||||
kernel_size=self.config.kernel_size,
|
||||
padding=1 # Fixed padding of 1 (hard-coded here, not read from config)
|
||||
)
|
||||
|
||||
# Initialize the unique layers with separate weights and biases
|
||||
self.unique_layers = nn.ModuleList()
|
||||
in_channels = self.config.hidden_size
|
||||
for _ in range(self.config.num_shared_layers):
|
||||
self.unique_layers.append(
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
self.config.hidden_size,
|
||||
kernel_size=self.config.kernel_size,
|
||||
padding=1
|
||||
)
|
||||
)
|
||||
current_in_channels = self.config.hidden_size
|
||||
|
||||
layer = nn.Conv2d(
|
||||
in_channels=current_in_channels,
|
||||
out_channels=self.config.hidden_size,
|
||||
kernel_size=self.config.kernel_size,
|
||||
padding=1 # Fixed padding of 1 (hard-coded here, not read from config)
|
||||
)
|
||||
|
||||
self.unique_layers.append(layer)
|
||||
|
||||
# Activation and pooling layers
|
||||
self.activation_function = getattr(nn, self.config.activation_function)()
|
||||
self.max_pool = nn.MaxPool2d(self.config.kernel_size)
|
||||
def _initialize_activation_andPooling(self):
|
||||
"""Initialize activation function and pooling layers."""
|
||||
# Get activation function from nn module
|
||||
self.activation = getattr(nn, self.config.activation_function)()
|
||||
|
||||
# Initialize max pooling layer
|
||||
self.max_pool = nn.MaxPool2d(
|
||||
kernel_size=self.config.kernel_size,
|
||||
padding=1 # Using same padding as in Conv2d layers
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.shared_layers:
|
||||
x = layer(x)
|
||||
x = self.activation_function(x)
|
||||
x = self.max_pool(x)
|
||||
"""Forward pass of the network."""
|
||||
# Apply shared layer transformation
|
||||
out = self.shared_layer(x)
|
||||
|
||||
for layer in self.unique_layers:
|
||||
x = layer(x)
|
||||
x = self.activation_function(x)
|
||||
x = self.max_pool(x)
|
||||
|
||||
return x
|
||||
|
||||
# Pass through activation function
|
||||
out = self.activation(out)
|
||||
|
||||
# Apply max pooling
|
||||
out = self.max_pool(out)
|
||||
|
||||
# Pass through unique layers
|
||||
for unique_layer in self.unique_layers:
|
||||
out = unique_layer(out)
|
||||
out = self.activation(out)
|
||||
out = self.max_pool(out)
|
||||
|
||||
return out
|
||||
|
||||
class AIIABase(AIIA):
|
||||
def __init__(self, config: AIIAConfig, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue