diff --git a/src/model/train.py b/src/model/train.py
index 39fd42e..730dcd9 100644
--- a/src/model/train.py
+++ b/src/model/train.py
@@ -1,16 +1,40 @@
 from sklearn.model_selection import train_test_split
-from transformers import BertTokenizer, BertForSequenceClassification, AdamW
+from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AdamW
 from torch.utils.data import DataLoader, TensorDataset
 import torch
 from tqdm import tqdm
 import pyarrow.parquet as pq
 
 class FakeNewsModelTrainer:
-    def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512):
+    def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512, size_factor=0.5):
         self.model_name = model_name
         self.max_length = max_length
+        self.size_factor = size_factor
         self.tokenizer = BertTokenizer.from_pretrained(model_name)
-        self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
+
+        # Load the original config
+        original_config = BertConfig.from_pretrained(model_name)
+
+        # Calculate new dimensions (heads first, so the hidden size can be
+        # rounded down to a multiple of the head count, as BERT requires)
+        new_num_attention_heads = max(int(original_config.num_attention_heads * size_factor ** 0.5), 1)
+        new_num_hidden_layers = max(int(original_config.num_hidden_layers * size_factor ** 0.5), 1)
+        new_hidden_size = max(max(int(original_config.hidden_size * size_factor ** 0.5), 16) // new_num_attention_heads, 1) * new_num_attention_heads
+
+        # Create a new config with reduced size
+        config = BertConfig(
+            vocab_size=original_config.vocab_size,
+            hidden_size=new_hidden_size,
+            num_hidden_layers=new_num_hidden_layers,
+            num_attention_heads=new_num_attention_heads,
+            intermediate_size=new_hidden_size * 4,
+            max_position_embeddings=original_config.max_position_embeddings,
+            num_labels=2
+        )
+
+        # Initialize the model with the new config
+        self.model = BertForSequenceClassification(config)
+
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.model.to(self.device)
 
@@ -91,10 +115,10 @@ if __name__ == '__main__':
     train_df, val_df = train_test_split(df, test_size=0.35, random_state=42)
 
     # Initialize and train the model
-    trainer = FakeNewsModelTrainer()
+    trainer = FakeNewsModelTrainer(size_factor=0.25)
     train_data = trainer.prepare_data(train_df)
     val_data = trainer.prepare_data(val_df)
     trainer.train(train_data, val_data)
 
     # Save the model
-    trainer.save_model('VeriMind')
+    trainer.save_model('VeriMindSmall')
\ No newline at end of file