updated config t add kwrags to support future changes to the config for different models

This commit is contained in:
Falko Victor Habel 2025-01-21 21:44:16 +01:00
parent 106539f48a
commit b87ce68c82
1 changed files with 8 additions and 3 deletions

View File

@ -11,7 +11,8 @@ class AIIAConfig:
hidden_size: int = 128, hidden_size: int = 128,
num_hidden_layers: int = 2, num_hidden_layers: int = 2,
num_channels: int = 3, num_channels: int = 3,
learning_rate: float = 5e5 learning_rate: float = 5e-5,
**kwargs
): ):
self.model_name = model_name self.model_name = model_name
self.kernel_size = kernel_size self.kernel_size = kernel_size
@ -20,6 +21,10 @@ class AIIAConfig:
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.num_channels = num_channels self.num_channels = num_channels
self.learning_rate = learning_rate self.learning_rate = learning_rate
# Store additional keyword arguments as attributes
for key, value in kwargs.items():
setattr(self, key, value)
@property @property
def activation_function(self): def activation_function(self):
@ -35,10 +40,10 @@ class AIIAConfig:
def save(self, file_path): def save(self, file_path):
with open(file_path, 'w') as f: with open(file_path, 'w') as f:
json.dump(self.__dict__, f) json.dump(vars(self), f)
@classmethod @classmethod
def load(cls, file_path): def load(cls, file_path):
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
config_dict = json.load(f) config_dict = json.load(f)
return cls(**config_dict) return cls(**config_dict)