updated config t add kwrags to support future changes to the config for different models
This commit is contained in:
parent
106539f48a
commit
b87ce68c82
|
@ -11,7 +11,8 @@ class AIIAConfig:
|
||||||
hidden_size: int = 128,
|
hidden_size: int = 128,
|
||||||
num_hidden_layers: int = 2,
|
num_hidden_layers: int = 2,
|
||||||
num_channels: int = 3,
|
num_channels: int = 3,
|
||||||
learning_rate: float = 5e5
|
learning_rate: float = 5e-5,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
|
@ -20,6 +21,10 @@ class AIIAConfig:
|
||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
self.num_channels = num_channels
|
self.num_channels = num_channels
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
|
|
||||||
|
# Store additional keyword arguments as attributes
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activation_function(self):
|
def activation_function(self):
|
||||||
|
@ -35,10 +40,10 @@ class AIIAConfig:
|
||||||
|
|
||||||
def save(self, file_path):
|
def save(self, file_path):
|
||||||
with open(file_path, 'w') as f:
|
with open(file_path, 'w') as f:
|
||||||
json.dump(self.__dict__, f)
|
json.dump(vars(self), f)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, file_path):
|
def load(cls, file_path):
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
return cls(**config_dict)
|
return cls(**config_dict)
|
Loading…
Reference in New Issue