diff --git a/src/model/train.py b/src/model/train.py
new file mode 100644
index 0000000..6fcafe2
--- /dev/null
+++ b/src/model/train.py
@@ -0,0 +1,100 @@
+import torch
+from torch.optim import AdamW
+from torch.utils.data import DataLoader, TensorDataset
+from sklearn.model_selection import train_test_split
+from transformers import BertTokenizer, BertForSequenceClassification
+from tqdm import tqdm
+import pyarrow.parquet as pq
+
+
+class FakeNewsModelTrainer:
+    """Fine-tunes multilingual BERT for binary fake-news classification."""
+
+    def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512):
+        self.model_name = model_name
+        self.max_length = max_length
+        self.tokenizer = BertTokenizer.from_pretrained(model_name)
+        self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.model.to(self.device)
+
+    def prepare_data(self, df):
+        """Tokenize a dataframe with 'text' and 'label' columns into a TensorDataset."""
+        texts = df['text'].tolist()
+        labels = df['label'].tolist()
+
+        encoded = self.tokenizer(texts, padding=True, truncation=True,
+                                 max_length=self.max_length, return_tensors='pt')
+
+        return TensorDataset(encoded['input_ids'], encoded['attention_mask'],
+                             torch.tensor(labels))
+
+    def train(self, train_data, val_data, epochs=3, batch_size=16):
+        train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
+        val_dataloader = DataLoader(val_data, batch_size=batch_size)
+
+        # torch.optim.AdamW replaces transformers.AdamW, which is deprecated.
+        optimizer = AdamW(self.model.parameters(), lr=2e-5)
+
+        for epoch in range(epochs):
+            self.model.train()
+            total_loss = 0
+
+            for batch in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{epochs}'):
+                input_ids, attention_mask, labels = [b.to(self.device) for b in batch]
+
+                optimizer.zero_grad()
+                outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
+                loss = outputs.loss
+                total_loss += loss.item()
+
+                loss.backward()
+                optimizer.step()
+
+            avg_train_loss = total_loss / len(train_dataloader)
+            print(f'Average training loss: {avg_train_loss:.4f}')
+
+            val_accuracy = self.evaluate(val_dataloader)
+            print(f'Validation accuracy: {val_accuracy:.4f}')
+
+    def evaluate(self, dataloader):
+        """Return plain accuracy over a dataloader."""
+        self.model.eval()
+        correct_predictions = 0
+        total_predictions = 0
+
+        with torch.no_grad():
+            for batch in dataloader:
+                input_ids, attention_mask, labels = [b.to(self.device) for b in batch]
+
+                outputs = self.model(input_ids, attention_mask=attention_mask)
+                preds = torch.argmax(outputs.logits, dim=1)
+
+                correct_predictions += torch.sum(preds == labels).item()
+                total_predictions += labels.shape[0]
+
+        return correct_predictions / total_predictions
+
+    def save_model(self, path):
+        """Save model weights and tokenizer together so they can be reloaded as a pair."""
+        self.model.save_pretrained(path)
+        self.tokenizer.save_pretrained(path)
+
+
+# Usage example
+if __name__ == '__main__':
+    # Load the dataset and combine title and body into a single input text
+    df = pq.read_table('your_dataset.parquet').to_pandas()
+    df['text'] = df['title'] + ' ' + df['text']
+
+    # Split the data into training and validation sets
+    train_df, val_df = train_test_split(df, test_size=0.3, random_state=42)
+
+    # Initialize and train the model
+    trainer = FakeNewsModelTrainer()
+    train_data = trainer.prepare_data(train_df)
+    val_data = trainer.prepare_data(val_df)
+    trainer.train(train_data, val_data)
+
+    # Save the fine-tuned model and tokenizer
+    trainer.save_model('VeriMind')
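
For review context, a minimal inference sketch showing how the artifact written by save_model would be reloaded; it is not part of this patch. The predict() helper and the mapping of label index 1 to "fake" are illustrative assumptions, not something the diff defines.

import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Reload the fine-tuned model and tokenizer saved by save_model('VeriMind').
tokenizer = BertTokenizer.from_pretrained('VeriMind')
model = BertForSequenceClassification.from_pretrained('VeriMind')
model.eval()

def predict(text, max_length=512):
    # Tokenize one article the same way as training (truncation to max_length).
    inputs = tokenizer(text, truncation=True, max_length=max_length, return_tensors='pt')
    with torch.no_grad():
        logits = model(**inputs).logits
    # Assumption: index 1 = fake, index 0 = real, matching num_labels=2 above.
    return 'fake' if torch.argmax(logits, dim=1).item() == 1 else 'real'

print(predict('Some headline. Some article body.'))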