Merge pull request 'added serilatsation' (#10) from half_precision into main
Reviewed-on: #10
This commit is contained in:
commit
ad0bcaee17
|
@ -40,14 +40,27 @@ class AIIAConfig:
|
|||
raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}")
|
||||
self._activation_function = value
|
||||
|
||||
def to_dict(self):
|
||||
# Recursively converts the object's attributes into serializable Python types.
|
||||
def serialize(value):
|
||||
if hasattr(value, "to_dict"):
|
||||
return value.to_dict()
|
||||
elif isinstance(value, list):
|
||||
return [serialize(item) for item in value]
|
||||
elif isinstance(value, dict):
|
||||
return {k: serialize(v) for k, v in value.items()}
|
||||
return value
|
||||
return {k: serialize(v) for k, v in self.__dict__.items()}
|
||||
|
||||
def save(self, file_path):
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
with open(f"{file_path}/config.json", 'w') as f:
|
||||
json.dump(vars(self), f, indent=4)
|
||||
with open(os.path.join(file_path, "config.json"), "w") as f:
|
||||
# Save the recursively converted dictionary.
|
||||
json.dump(self.to_dict(), f, indent=4)
|
||||
|
||||
@classmethod
|
||||
def load(cls, file_path):
|
||||
with open(f"{file_path}/config.json", 'r') as f:
|
||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||
config_dict = json.load(f)
|
||||
return cls(**config_dict)
|
Loading…
Reference in New Issue