feat/tf_support #37
|
@ -1,29 +1,25 @@
|
||||||
import torch
|
from transformers import PretrainedConfig
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
class AIIAConfig(PretrainedConfig):
|
||||||
|
model_type = "AIIA" # Add this class attribute
|
||||||
|
|
||||||
class AIIAConfig:
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str = "AIIA",
|
|
||||||
kernel_size: int = 3,
|
kernel_size: int = 3,
|
||||||
activation_function: str = "GELU",
|
activation_function: str = "GELU",
|
||||||
hidden_size: int = 512,
|
hidden_size: int = 512,
|
||||||
num_hidden_layers: int = 12,
|
num_hidden_layers: int = 12,
|
||||||
num_channels: int = 3,
|
num_channels: int = 3,
|
||||||
learning_rate: float = 5e-5,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
super().__init__(**kwargs)
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.activation_function = activation_function
|
self.activation_function = activation_function
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
self.num_channels = num_channels
|
self.num_channels = num_channels
|
||||||
self.learning_rate = learning_rate
|
|
||||||
|
|
||||||
# Store additional keyword arguments as attributes
|
# Store additional keyword arguments as attributes
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
@ -50,17 +46,4 @@ class AIIAConfig:
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
return {k: serialize(v) for k, v in value.items()}
|
return {k: serialize(v) for k, v in value.items()}
|
||||||
return value
|
return value
|
||||||
return {k: serialize(v) for k, v in self.__dict__.items()}
|
return {k: serialize(v) for k, v in self.__dict__.items()}
|
||||||
|
|
||||||
def save(self, file_path):
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
os.makedirs(file_path, exist_ok=True)
|
|
||||||
with open(os.path.join(file_path, "config.json"), "w") as f:
|
|
||||||
# Save the recursively converted dictionary.
|
|
||||||
json.dump(self.to_dict(), f, indent=4)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, file_path):
|
|
||||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
|
||||||
config_dict = json.load(f)
|
|
||||||
return cls(**config_dict)
|
|
Loading…
Reference in New Issue