updated size

Falko Victor Habel 2024-08-31 14:14:05 +02:00
parent 4327bbf05a
commit be17dd3d17
1 changed file with 28 additions and 5 deletions


@@ -1,16 +1,39 @@
 from sklearn.model_selection import train_test_split
-from transformers import BertTokenizer, BertForSequenceClassification, AdamW
+from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AdamW
 from torch.utils.data import DataLoader, TensorDataset
 import torch
 from tqdm import tqdm
 import pyarrow.parquet as pq
 
 class FakeNewsModelTrainer:
-    def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512):
+    def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512, size_factor=0.5):
         self.model_name = model_name
         self.max_length = max_length
+        self.size_factor = size_factor
         self.tokenizer = BertTokenizer.from_pretrained(model_name)
-        self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
+
+        # Load the original config
+        original_config = BertConfig.from_pretrained(model_name)
+
+        # Calculate new dimensions
+        new_hidden_size = max(int(original_config.hidden_size * size_factor ** 0.5), 16)
+        new_num_hidden_layers = max(int(original_config.num_hidden_layers * size_factor ** 0.5), 1)
+        new_num_attention_heads = max(int(original_config.num_attention_heads * size_factor ** 0.5), 1)
+
+        # Create a new config with reduced size
+        config = BertConfig(
+            vocab_size=original_config.vocab_size,
+            hidden_size=new_hidden_size,
+            num_hidden_layers=new_num_hidden_layers,
+            num_attention_heads=new_num_attention_heads,
+            intermediate_size=new_hidden_size * 4,
+            max_position_embeddings=original_config.max_position_embeddings,
+            num_labels=2
+        )
+
+        # Initialize the model with the new config
+        self.model = BertForSequenceClassification(config)
+
 
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.model.to(self.device)
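
As a sanity check on the arithmetic introduced in __init__, the short snippet below (illustrative only, not part of the commit) evaluates the same square-root scaling against the published bert-base-multilingual-cased dimensions (hidden_size 768, 12 hidden layers, 12 attention heads) for the size_factor=0.25 used further down:

# Illustrative only: reproduces the dimension arithmetic from the hunk above
size_factor = 0.25
scale = size_factor ** 0.5                           # 0.5
new_hidden_size = max(int(768 * scale), 16)          # 384
new_num_hidden_layers = max(int(12 * scale), 1)      # 6
new_num_attention_heads = max(int(12 * scale), 1)    # 6
intermediate_size = new_hidden_size * 4              # 1536

With these values, hidden_size (384) stays divisible by the attention-head count (6), which BERT's self-attention requires; other size_factor choices are not guaranteed to satisfy that constraint, so the factor is worth checking before training.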
@@ -91,10 +114,10 @@ if __name__ == '__main__':
     train_df, val_df = train_test_split(df, test_size=0.35, random_state=42)
     # Initialize and train the model
-    trainer = FakeNewsModelTrainer()
+    trainer = FakeNewsModelTrainer(size_factor=0.25)
     train_data = trainer.prepare_data(train_df)
     val_data = trainer.prepare_data(val_df)
     trainer.train(train_data, val_data)
     # Save the model
-    trainer.save_model('VeriMind')
+    trainer.save_model('VeriMindSmall')
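
For context, here is a minimal standalone sketch of the same construction, assuming transformers and torch are installed and the Hugging Face Hub is reachable; it mirrors the model_name and size_factor values from the diff but is not itself part of the commit:

# Illustrative sketch: build the reduced config the way the trainer does
# and report the resulting parameter count.
from transformers import BertConfig, BertForSequenceClassification

model_name = 'google-bert/bert-base-multilingual-cased'
size_factor = 0.25
scale = size_factor ** 0.5

original = BertConfig.from_pretrained(model_name)
new_hidden_size = max(int(original.hidden_size * scale), 16)

config = BertConfig(
    vocab_size=original.vocab_size,
    hidden_size=new_hidden_size,
    num_hidden_layers=max(int(original.num_hidden_layers * scale), 1),
    num_attention_heads=max(int(original.num_attention_heads * scale), 1),
    intermediate_size=new_hidden_size * 4,
    max_position_embeddings=original.max_position_embeddings,
    num_labels=2,
)

# Unlike from_pretrained, this initializes every weight randomly.
small_model = BertForSequenceClassification(config)
print(f"reduced model parameters: {small_model.num_parameters():,}")

Because no pretrained weights are loaded, the reduced model starts from random initialization and has to be trained from scratch on the fake-news data rather than fine-tuned.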