From 1faf34749ad5c84c2f85622a5942c99ab1d5b282 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Mon, 24 Feb 2025 15:16:46 +0100 Subject: [PATCH] added serilatsation --- src/aiia/model/config.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/aiia/model/config.py b/src/aiia/model/config.py index 02bc709..a83329a 100644 --- a/src/aiia/model/config.py +++ b/src/aiia/model/config.py @@ -40,14 +40,27 @@ class AIIAConfig: raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}") self._activation_function = value + def to_dict(self): + # Recursively converts the object's attributes into serializable Python types. + def serialize(value): + if hasattr(value, "to_dict"): + return value.to_dict() + elif isinstance(value, list): + return [serialize(item) for item in value] + elif isinstance(value, dict): + return {k: serialize(v) for k, v in value.items()} + return value + return {k: serialize(v) for k, v in self.__dict__.items()} + def save(self, file_path): if not os.path.exists(file_path): os.makedirs(file_path, exist_ok=True) - with open(f"{file_path}/config.json", 'w') as f: - json.dump(vars(self), f, indent=4) + with open(os.path.join(file_path, "config.json"), "w") as f: + # Save the recursively converted dictionary. + json.dump(self.to_dict(), f, indent=4) @classmethod def load(cls, file_path): - with open(f"{file_path}/config.json", 'r') as f: + with open(os.path.join(file_path, "config.json"), "r") as f: config_dict = json.load(f) return cls(**config_dict) \ No newline at end of file