added serilatsation #10

Merged
Fabel merged 1 commits from half_precision into main 2025-02-24 14:17:16 +00:00
1 changed files with 16 additions and 3 deletions

View File

@ -40,14 +40,27 @@ class AIIAConfig:
raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}")
self._activation_function = value
def to_dict(self):
# Recursively converts the object's attributes into serializable Python types.
def serialize(value):
if hasattr(value, "to_dict"):
return value.to_dict()
elif isinstance(value, list):
return [serialize(item) for item in value]
elif isinstance(value, dict):
return {k: serialize(v) for k, v in value.items()}
return value
return {k: serialize(v) for k, v in self.__dict__.items()}
def save(self, file_path):
if not os.path.exists(file_path):
os.makedirs(file_path, exist_ok=True)
with open(f"{file_path}/config.json", 'w') as f:
json.dump(vars(self), f, indent=4)
with open(os.path.join(file_path, "config.json"), "w") as f:
# Save the recursively converted dictionary.
json.dump(self.to_dict(), f, indent=4)
@classmethod
def load(cls, file_path):
with open(f"{file_path}/config.json", 'r') as f:
with open(os.path.join(file_path, "config.json"), "r") as f:
config_dict = json.load(f)
return cls(**config_dict)