feat/tf_support #37

Merged
Fabel merged 13 commits from feat/tf_support into develop 2025-04-16 20:59:48 +00:00
1 changed file with 6 additions and 23 deletions
Showing only changes of commit a8ddf2b559 - Show all commits

View File

@ -1,28 +1,24 @@
import torch
from transformers import PretrainedConfig
import torch.nn as nn
import json
import os
class AIIAConfig(PretrainedConfig):
model_type = "AIIA" # Add this class attribute
class AIIAConfig:
def __init__(
self,
model_name: str = "AIIA",
kernel_size: int = 3,
activation_function: str = "GELU",
hidden_size: int = 512,
num_hidden_layers: int = 12,
num_channels: int = 3,
learning_rate: float = 5e-5,
**kwargs
):
self.model_name = model_name
super().__init__(**kwargs)
self.kernel_size = kernel_size
self.activation_function = activation_function
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_channels = num_channels
self.learning_rate = learning_rate
# Store additional keyword arguments as attributes
for key, value in kwargs.items():
@ -51,16 +47,3 @@ class AIIAConfig:
return {k: serialize(v) for k, v in value.items()}
return value
return {k: serialize(v) for k, v in self.__dict__.items()}
def save(self, file_path):
    """Write this configuration to ``<file_path>/config.json``.

    ``file_path`` is treated as a directory; it is created, together with
    any missing parent directories, when it does not exist yet.

    Args:
        file_path: Directory that will contain ``config.json``.

    Raises:
        OSError: If the directory cannot be created or the file cannot be
            written.
        TypeError: If the dictionary returned by ``to_dict()`` contains a
            value that ``json`` cannot serialize.
    """
    # Serialize before touching the filesystem: opening the target in "w"
    # mode truncates it, so a serialization failure after that point would
    # destroy a previously saved, valid config.json.
    serialized = json.dumps(self.to_dict(), indent=4)
    # exist_ok=True already tolerates an existing directory, so a separate
    # os.path.exists() check is redundant and racy.
    os.makedirs(file_path, exist_ok=True)
    with open(os.path.join(file_path, "config.json"), "w", encoding="utf-8") as f:
        f.write(serialized)
@classmethod
def load(cls, file_path):
    """Build a configuration from ``<file_path>/config.json``.

    Args:
        file_path: Directory containing ``config.json``.

    Returns:
        A new ``cls`` instance created by passing every top-level key of
        the stored JSON object to the constructor as a keyword argument.

    Raises:
        OSError: If ``config.json`` cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification; do not rely on the platform's
    # locale encoding, which can mis-decode non-ASCII values.
    with open(os.path.join(file_path, "config.json"), "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    return cls(**config_dict)