test_model #13

Merged
Fabel merged 21 commits from test_model into develop 2024-09-03 08:53:54 +00:00
1 changed file with 28 additions and 5 deletions
Showing only changes of commit be17dd3d17 - Show all commits

View File

@@ -1,16 +1,39 @@
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AdamW
from torch.utils.data import DataLoader, TensorDataset
import torch
from tqdm import tqdm
import pyarrow.parquet as pq
class FakeNewsModelTrainer:
    def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512, size_factor=0.5):
        """Build a size-reduced, randomly initialized BERT classifier.

        The pretrained checkpoint is used only as an architectural template:
        each dimension (hidden size, layer count, head count) is scaled by
        sqrt(size_factor), so the parameter count shrinks roughly linearly
        with size_factor.

        Args:
            model_name: Hugging Face model id for the tokenizer and template config.
            max_length: Maximum token sequence length used by the tokenizer.
            size_factor: Fraction (0 < size_factor <= 1) of the original model size.
        """
        self.model_name = model_name
        self.max_length = max_length
        self.size_factor = size_factor
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        # Load only the config (not the weights) as the architecture template.
        original_config = BertConfig.from_pretrained(model_name)
        # Scale each dimension by sqrt(size_factor) so total parameters
        # shrink roughly proportionally to size_factor.
        scale = size_factor ** 0.5
        new_num_hidden_layers = max(int(original_config.num_hidden_layers * scale), 1)
        new_num_attention_heads = max(int(original_config.num_attention_heads * scale), 1)
        new_hidden_size = max(int(original_config.hidden_size * scale), 16)
        # BERT requires hidden_size to be divisible by num_attention_heads;
        # otherwise model construction raises (e.g. the default
        # size_factor=0.5 gives 543 hidden / 8 heads). Round the scaled
        # hidden size down to the nearest valid multiple.
        new_hidden_size = max(
            (new_hidden_size // new_num_attention_heads) * new_num_attention_heads,
            new_num_attention_heads,
        )
        config = BertConfig(
            vocab_size=original_config.vocab_size,
            hidden_size=new_hidden_size,
            num_hidden_layers=new_num_hidden_layers,
            num_attention_heads=new_num_attention_heads,
            # Standard BERT ratio: FFN width = 4 x hidden size.
            intermediate_size=new_hidden_size * 4,
            max_position_embeddings=original_config.max_position_embeddings,
            num_labels=2,
        )
        # Randomly initialized model with the reduced config — intentionally
        # NOT loading pretrained weights (shapes would not match anyway).
        self.model = BertForSequenceClassification(config)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
@@ -91,10 +114,10 @@ if __name__ == '__main__':
# Hold out 35% of the data for validation; fixed seed for reproducibility.
train_df, val_df = train_test_split(df, test_size=0.35, random_state=42)
# Initialize and train the reduced-size model (~25% of the BERT-base budget).
trainer = FakeNewsModelTrainer(size_factor=0.25)
train_data = trainer.prepare_data(train_df)
val_data = trainer.prepare_data(val_df)
trainer.train(train_data, val_data)
# Save the trained model under a name reflecting the reduced size.
trainer.save_model('VeriMindSmall')