feat/tf_support #37

Merged
Fabel merged 13 commits from feat/tf_support into develop 2025-04-16 20:59:48 +00:00
1 changed files with 6 additions and 23 deletions
Showing only changes of commit a8ddf2b559 - Show all commits

View File

@ -1,28 +1,24 @@
import torch from transformers import PretrainedConfig
import torch.nn as nn import torch.nn as nn
import json
import os
class AIIAConfig(PretrainedConfig):
model_type = "AIIA" # Add this class attribute
class AIIAConfig:
def __init__( def __init__(
self, self,
model_name: str = "AIIA",
kernel_size: int = 3, kernel_size: int = 3,
activation_function: str = "GELU", activation_function: str = "GELU",
hidden_size: int = 512, hidden_size: int = 512,
num_hidden_layers: int = 12, num_hidden_layers: int = 12,
num_channels: int = 3, num_channels: int = 3,
learning_rate: float = 5e-5,
**kwargs **kwargs
): ):
self.model_name = model_name super().__init__(**kwargs)
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.activation_function = activation_function self.activation_function = activation_function
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.num_channels = num_channels self.num_channels = num_channels
self.learning_rate = learning_rate
# Store additional keyword arguments as attributes # Store additional keyword arguments as attributes
for key, value in kwargs.items(): for key, value in kwargs.items():
@ -51,16 +47,3 @@ class AIIAConfig:
return {k: serialize(v) for k, v in value.items()} return {k: serialize(v) for k, v in value.items()}
return value return value
return {k: serialize(v) for k, v in self.__dict__.items()} return {k: serialize(v) for k, v in self.__dict__.items()}
def save(self, file_path):
if not os.path.exists(file_path):
os.makedirs(file_path, exist_ok=True)
with open(os.path.join(file_path, "config.json"), "w") as f:
# Save the recursively converted dictionary.
json.dump(self.to_dict(), f, indent=4)
@classmethod
def load(cls, file_path):
with open(os.path.join(file_path, "config.json"), "r") as f:
config_dict = json.load(f)
return cls(**config_dict)