added serilatsation

This commit is contained in:
Falko Victor Habel 2025-02-24 15:16:46 +01:00
parent c9b6a8926b
commit 1faf34749a
1 changed files with 16 additions and 3 deletions

View File

@ -40,14 +40,27 @@ class AIIAConfig:
raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}") raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}")
self._activation_function = value self._activation_function = value
def to_dict(self):
# Recursively converts the object's attributes into serializable Python types.
def serialize(value):
if hasattr(value, "to_dict"):
return value.to_dict()
elif isinstance(value, list):
return [serialize(item) for item in value]
elif isinstance(value, dict):
return {k: serialize(v) for k, v in value.items()}
return value
return {k: serialize(v) for k, v in self.__dict__.items()}
def save(self, file_path): def save(self, file_path):
if not os.path.exists(file_path): if not os.path.exists(file_path):
os.makedirs(file_path, exist_ok=True) os.makedirs(file_path, exist_ok=True)
with open(f"{file_path}/config.json", 'w') as f: with open(os.path.join(file_path, "config.json"), "w") as f:
json.dump(vars(self), f, indent=4) # Save the recursively converted dictionary.
json.dump(self.to_dict(), f, indent=4)
@classmethod @classmethod
def load(cls, file_path): def load(cls, file_path):
with open(f"{file_path}/config.json", 'r') as f: with open(os.path.join(file_path, "config.json"), "r") as f:
config_dict = json.load(f) config_dict = json.load(f)
return cls(**config_dict) return cls(**config_dict)