added serilatsation
This commit is contained in:
parent
c9b6a8926b
commit
1faf34749a
|
@ -40,14 +40,27 @@ class AIIAConfig:
|
||||||
raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}")
|
raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}")
|
||||||
self._activation_function = value
|
self._activation_function = value
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
# Recursively converts the object's attributes into serializable Python types.
|
||||||
|
def serialize(value):
|
||||||
|
if hasattr(value, "to_dict"):
|
||||||
|
return value.to_dict()
|
||||||
|
elif isinstance(value, list):
|
||||||
|
return [serialize(item) for item in value]
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
return {k: serialize(v) for k, v in value.items()}
|
||||||
|
return value
|
||||||
|
return {k: serialize(v) for k, v in self.__dict__.items()}
|
||||||
|
|
||||||
def save(self, file_path):
|
def save(self, file_path):
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
with open(f"{file_path}/config.json", 'w') as f:
|
with open(os.path.join(file_path, "config.json"), "w") as f:
|
||||||
json.dump(vars(self), f, indent=4)
|
# Save the recursively converted dictionary.
|
||||||
|
json.dump(self.to_dict(), f, indent=4)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, file_path):
|
def load(cls, file_path):
|
||||||
with open(f"{file_path}/config.json", 'r') as f:
|
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
return cls(**config_dict)
|
return cls(**config_dict)
|
Loading…
Reference in New Issue