Updated size
This commit is contained in:
parent
4327bbf05a
commit
be17dd3d17
|
@ -1,16 +1,39 @@
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
|
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AdamW
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
|
|
||||||
class FakeNewsModelTrainer:
|
class FakeNewsModelTrainer:
|
||||||
def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512, size_factor=0.5):
    """Build a down-scaled BERT sequence classifier and its tokenizer.

    The tokenizer is loaded from ``model_name``; the model itself is
    created from a fresh ``BertConfig`` whose dimensions are scaled from
    the original checkpoint's config.  Because the model is constructed
    from a config (not via ``from_pretrained``), its weights are newly
    initialised rather than loaded from the checkpoint.

    Args:
        model_name: Hugging Face model id used for the tokenizer and as
            the source of the reference configuration.
        max_length: Stored on the instance as ``self.max_length``
            (presumably the tokenizer truncation length; its use is
            outside this method).
        size_factor: Approximate scale applied to the model.  Hidden
            size, layer count and attention-head count are each scaled
            by ``sqrt(size_factor)``.

    Raises:
        ValueError: If ``size_factor`` is negative.
    """
    if size_factor < 0:
        # A negative factor would make ``size_factor ** 0.5`` complex and
        # fail later with an obscure TypeError inside int().
        raise ValueError(f"size_factor must be non-negative, got {size_factor!r}")

    self.model_name = model_name
    self.max_length = max_length
    self.size_factor = size_factor
    self.tokenizer = BertTokenizer.from_pretrained(model_name)

    # Load the original config
    original_config = BertConfig.from_pretrained(model_name)

    # Calculate new dimensions
    scale = size_factor ** 0.5
    new_num_hidden_layers = max(int(original_config.num_hidden_layers * scale), 1)
    new_num_attention_heads = max(int(original_config.num_attention_heads * scale), 1)
    target_hidden_size = max(int(original_config.hidden_size * scale), 16)
    # BERT requires hidden_size to be an exact multiple of the number of
    # attention heads; truncating both independently (e.g. 543 and 8 for
    # size_factor=0.5) makes model construction fail.  Round the hidden
    # size down to the nearest multiple, keeping at least one unit per head.
    new_hidden_size = (
        max(target_hidden_size // new_num_attention_heads, 1)
        * new_num_attention_heads
    )

    # Create a new config with reduced size
    config = BertConfig(
        vocab_size=original_config.vocab_size,
        hidden_size=new_hidden_size,
        num_hidden_layers=new_num_hidden_layers,
        num_attention_heads=new_num_attention_heads,
        intermediate_size=new_hidden_size * 4,
        max_position_embeddings=original_config.max_position_embeddings,
        num_labels=2,
    )

    # Initialize the model with the new config
    self.model = BertForSequenceClassification(config)

    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.model.to(self.device)
|
||||||
|
|
||||||
|
@ -91,10 +114,10 @@ if __name__ == '__main__':
|
||||||
train_df, val_df = train_test_split(df, test_size=0.35, random_state=42)
|
train_df, val_df = train_test_split(df, test_size=0.35, random_state=42)
|
||||||
|
|
||||||
# Initialize and train the model
|
# Initialize and train the model
|
||||||
trainer = FakeNewsModelTrainer()
|
trainer = FakeNewsModelTrainer(size_factor=0.25)
|
||||||
train_data = trainer.prepare_data(train_df)
|
train_data = trainer.prepare_data(train_df)
|
||||||
val_data = trainer.prepare_data(val_df)
|
val_data = trainer.prepare_data(val_df)
|
||||||
trainer.train(train_data, val_data)
|
trainer.train(train_data, val_data)
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
trainer.save_model('VeriMind')
|
trainer.save_model('VeriMindSmall')
|
Reference in New Issue