updated size

Falko Victor Habel 2024-08-31 14:14:05 +02:00
parent 4327bbf05a
commit be17dd3d17
1 changed file with 28 additions and 5 deletions


@@ -1,16 +1,39 @@
 from sklearn.model_selection import train_test_split
-from transformers import BertTokenizer, BertForSequenceClassification, AdamW
+from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AdamW
 from torch.utils.data import DataLoader, TensorDataset
 import torch
 from tqdm import tqdm
 import pyarrow.parquet as pq
 
 
 class FakeNewsModelTrainer:
-    def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512):
+    def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512, size_factor=0.5):
         self.model_name = model_name
         self.max_length = max_length
+        self.size_factor = size_factor
         self.tokenizer = BertTokenizer.from_pretrained(model_name)
-        self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
+
+        # Load the original config
+        original_config = BertConfig.from_pretrained(model_name)
+
+        # Calculate new dimensions
+        new_hidden_size = max(int(original_config.hidden_size * size_factor ** 0.5), 16)
+        new_num_hidden_layers = max(int(original_config.num_hidden_layers * size_factor ** 0.5), 1)
+        new_num_attention_heads = max(int(original_config.num_attention_heads * size_factor ** 0.5), 1)
+
+        # Create a new config with reduced size
+        config = BertConfig(
+            vocab_size=original_config.vocab_size,
+            hidden_size=new_hidden_size,
+            num_hidden_layers=new_num_hidden_layers,
+            num_attention_heads=new_num_attention_heads,
+            intermediate_size=new_hidden_size * 4,
+            max_position_embeddings=original_config.max_position_embeddings,
+            num_labels=2
+        )
+
+        # Initialize the model with the new config
+        self.model = BertForSequenceClassification(config)
+
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.model.to(self.device)
@@ -91,10 +114,10 @@ if __name__ == '__main__':
     train_df, val_df = train_test_split(df, test_size=0.35, random_state=42)
 
     # Initialize and train the model
-    trainer = FakeNewsModelTrainer()
+    trainer = FakeNewsModelTrainer(size_factor=0.25)
     train_data = trainer.prepare_data(train_df)
     val_data = trainer.prepare_data(val_df)
     trainer.train(train_data, val_data)
 
     # Save the model
-    trainer.save_model('VeriMind')
+    trainer.save_model('VeriMindSmall')
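
For reference, a minimal sketch (not part of the commit) of how size_factor scales the architecture. The base dimensions assume google-bert/bert-base-multilingual-cased (hidden_size=768, 12 layers, 12 attention heads), and the helper name shrunk_dims is hypothetical; it only mirrors the scaling rule used in __init__.

def shrunk_dims(size_factor, hidden_size=768, num_hidden_layers=12, num_attention_heads=12):
    # Same rule as __init__: every dimension shrinks by sqrt(size_factor),
    # clamped to a small minimum.
    scale = size_factor ** 0.5
    new_hidden = max(int(hidden_size * scale), 16)
    return {
        'hidden_size': new_hidden,
        'num_hidden_layers': max(int(num_hidden_layers * scale), 1),
        'num_attention_heads': max(int(num_attention_heads * scale), 1),
        'intermediate_size': new_hidden * 4,
    }

print(shrunk_dims(0.25))
# {'hidden_size': 384, 'num_hidden_layers': 6, 'num_attention_heads': 6, 'intermediate_size': 1536}

With size_factor=0.25, as used in __main__, each dimension is multiplied by 0.25 ** 0.5 = 0.5, and 384 remains divisible by the 6 attention heads, which BERT's attention layers require (hidden_size % num_attention_heads == 0).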