bert config added

This commit is contained in:
Falko Victor Habel 2024-08-31 21:40:33 +02:00
parent a824724854
commit 306cd5619d
1 changed file with 24 additions and 2 deletions


@@ -1,7 +1,7 @@
 import pandas as pd
 import numpy as np
 from sklearn.model_selection import train_test_split
-from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
+from transformers import BertTokenizer, BertConfig, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
 from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
 import torch
 from tqdm import tqdm
@@ -9,7 +9,7 @@ import pyarrow.parquet as pq
 from sklearn.metrics import classification_report, confusion_matrix

 class FakeNewsModelTrainer:
-    def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512):
+    def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512, size_factor=0.5):
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
@@ -17,6 +17,28 @@ class FakeNewsModelTrainer:
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.model.to(self.device)
+
+        # Load the original config
+        original_config = BertConfig.from_pretrained(model_name)
+
+        # Calculate new dimensions
+        new_hidden_size = max(int(original_config.hidden_size * size_factor ** 0.5), 16)
+        new_num_hidden_layers = max(int(original_config.num_hidden_layers * size_factor ** 0.5), 1)
+        new_num_attention_heads = max(int(original_config.num_attention_heads * size_factor ** 0.5), 1)
+
+        # Create a new config with reduced size
+        config = BertConfig(
+            vocab_size=original_config.vocab_size,
+            hidden_size=new_hidden_size,
+            num_hidden_layers=new_num_hidden_layers,
+            num_attention_heads=new_num_attention_heads,
+            intermediate_size=new_hidden_size * 4,
+            max_position_embeddings=original_config.max_position_embeddings,
+            num_labels=2
+        )
+
+        # Initialize the model with the new config
+        self.model = BertForSequenceClassification(config)

     def prepare_data(self, df):
         texts = df.apply(lambda row: f"{row['title'] or ''} {row['text'] or ''}".strip(), axis=1).tolist()
         labels = df['label'].tolist()
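
For reference, here is a minimal standalone sketch (not part of the commit) of the scaling scheme the diff introduces. One caveat with the committed code: BERT requires hidden_size to be divisible by num_attention_heads, and with the default size_factor=0.5 the formulas yield int(768 * 0.5 ** 0.5) = 543 hidden units against 8 heads, which fails that check when the model is instantiated. The sketch below therefore rounds the hidden size down to a multiple of the head count; the model name and print statement are illustrative.

    # Sketch, assuming the transformers library; reproduces the size_factor
    # scaling from the diff and keeps hidden_size divisible by the head count.
    from transformers import BertConfig, BertForSequenceClassification

    size_factor = 0.5
    original = BertConfig.from_pretrained('google-bert/bert-base-multilingual-cased')

    new_heads = max(int(original.num_attention_heads * size_factor ** 0.5), 1)
    new_hidden = max(int(original.hidden_size * size_factor ** 0.5), 16)
    new_hidden -= new_hidden % new_heads  # 543 -> 536, divisible by 8 heads

    reduced = BertConfig(
        vocab_size=original.vocab_size,
        hidden_size=new_hidden,
        num_hidden_layers=max(int(original.num_hidden_layers * size_factor ** 0.5), 1),
        num_attention_heads=new_heads,
        intermediate_size=new_hidden * 4,  # keep BERT's 4x feed-forward ratio
        max_position_embeddings=original.max_position_embeddings,
        num_labels=2,
    )

    # Built from a config, so weights are randomly initialized, not pretrained
    model = BertForSequenceClassification(reduced)
    print(f"parameters: {sum(p.numel() for p in model.parameters()):,}")

Two further properties of the committed version worth noting: constructing BertForSequenceClassification from a config initializes weights randomly rather than loading the pretrained checkpoint, and the replacement model is assigned after the earlier self.model.to(self.device) call, so it stays on the CPU until moved again.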