From a8ddf2b55971859a6549fe363f2513521cc0a2fa Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sun, 13 Apr 2025 22:18:26 +0200 Subject: [PATCH] updated config to not handle custom model_name; removed learning rate and made it transformer compatible --- src/aiia/model/config.py | 29 ++++++----------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/src/aiia/model/config.py b/src/aiia/model/config.py index a83329a..7bf3d93 100644 --- a/src/aiia/model/config.py +++ b/src/aiia/model/config.py @@ -1,29 +1,25 @@ -import torch +from transformers import PretrainedConfig import torch.nn as nn -import json -import os +class AIIAConfig(PretrainedConfig): + model_type = "AIIA" # Add this class attribute -class AIIAConfig: def __init__( self, - model_name: str = "AIIA", kernel_size: int = 3, activation_function: str = "GELU", hidden_size: int = 512, num_hidden_layers: int = 12, num_channels: int = 3, - learning_rate: float = 5e-5, **kwargs ): - self.model_name = model_name + super().__init__(**kwargs) self.kernel_size = kernel_size self.activation_function = activation_function self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_channels = num_channels - self.learning_rate = learning_rate - + # Store additional keyword arguments as attributes for key, value in kwargs.items(): setattr(self, key, value) @@ -50,17 +46,4 @@ class AIIAConfig: elif isinstance(value, dict): return {k: serialize(v) for k, v in value.items()} return value - return {k: serialize(v) for k, v in self.__dict__.items()} - - def save(self, file_path): - if not os.path.exists(file_path): - os.makedirs(file_path, exist_ok=True) - with open(os.path.join(file_path, "config.json"), "w") as f: - # Save the recursively converted dictionary. 
- json.dump(self.to_dict(), f, indent=4) - - @classmethod - def load(cls, file_path): - with open(os.path.join(file_path, "config.json"), "r") as f: - config_dict = json.load(f) - return cls(**config_dict) \ No newline at end of file + return {k: serialize(v) for k, v in self.__dict__.items()} \ No newline at end of file