From a82472485468254247b1d95f9fb6caa70c4df58d Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sat, 31 Aug 2024 21:40:20 +0200 Subject: [PATCH] inclucded inference into training and added weighted sampler --- src/model/train.py | 120 +++++++++++++++++++++++++++++---------------- 1 file changed, 77 insertions(+), 43 deletions(-) diff --git a/src/model/train.py b/src/model/train.py index 7dd9411..bd25bf6 100644 --- a/src/model/train.py +++ b/src/model/train.py @@ -1,39 +1,19 @@ +import pandas as pd +import numpy as np from sklearn.model_selection import train_test_split -from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AdamW -from torch.utils.data import DataLoader, TensorDataset +from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup +from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler import torch from tqdm import tqdm import pyarrow.parquet as pq +from sklearn.metrics import classification_report, confusion_matrix class FakeNewsModelTrainer: - def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512, size_factor=0.5): + def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512): self.model_name = model_name self.max_length = max_length - self.size_factor = size_factor self.tokenizer = BertTokenizer.from_pretrained(model_name) - - # Load the original config - original_config = BertConfig.from_pretrained(model_name) - - # Calculate new dimensions - new_hidden_size = max(int(original_config.hidden_size * size_factor ** 0.5), 16) - new_num_hidden_layers = max(int(original_config.num_hidden_layers * size_factor ** 0.5), 1) - new_num_attention_heads = max(int(original_config.num_attention_heads * size_factor ** 0.5), 1) - - # Create a new config with reduced size - config = BertConfig( - vocab_size=original_config.vocab_size, - hidden_size=new_hidden_size, - num_hidden_layers=new_num_hidden_layers, - num_attention_heads=new_num_attention_heads, - intermediate_size=new_hidden_size * 4, - max_position_embeddings=original_config.max_position_embeddings, - num_labels=2 - ) - - # Initialize the model with the new config - self.model = BertForSequenceClassification(config) - + self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2) self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.model.to(self.device) @@ -56,13 +36,23 @@ class FakeNewsModelTrainer: attention_mask = encoded_texts['attention_mask'] labels = torch.tensor(valid_labels) - return TensorDataset(input_ids, attention_mask, labels) + # Create a weighted sampler for balanced batches + class_sample_count = np.array([len(np.where(valid_labels == t)[0]) for t in np.unique(valid_labels)]) + weight = 1. / class_sample_count + samples_weight = np.array([weight[t] for t in valid_labels]) + samples_weight = torch.from_numpy(samples_weight) + sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight)) - def train(self, train_data, val_data, epochs=13, batch_size=64): - train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True) - val_dataloader = DataLoader(val_data, batch_size=batch_size) + return TensorDataset(input_ids, attention_mask, labels), sampler - optimizer = AdamW(self.model.parameters(), lr=2e-5) + def train(self, train_data, val_data, epochs=5, batch_size=32): + train_dataset, train_sampler = train_data + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size) + val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False) + + optimizer = AdamW(self.model.parameters(), lr=2e-5, eps=1e-8) + total_steps = len(train_dataloader) * epochs + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps) for epoch in range(epochs): self.model.train() @@ -71,24 +61,28 @@ class FakeNewsModelTrainer: for batch in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{epochs}'): input_ids, attention_mask, labels = [b.to(self.device) for b in batch] + self.model.zero_grad() outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels) loss = outputs.loss total_loss += loss.item() loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) optimizer.step() - optimizer.zero_grad() + scheduler.step() avg_train_loss = total_loss / len(train_dataloader) print(f'Average training loss: {avg_train_loss:.4f}') - val_accuracy = self.evaluate(val_dataloader) + val_accuracy, val_report = self.evaluate(val_dataloader) print(f'Validation accuracy: {val_accuracy:.4f}') + print('Validation Classification Report:') + print(val_report) def evaluate(self, dataloader): self.model.eval() - correct_predictions = 0 - total_predictions = 0 + predictions = [] + true_labels = [] with torch.no_grad(): for batch in dataloader: @@ -97,27 +91,67 @@ class FakeNewsModelTrainer: outputs = self.model(input_ids, attention_mask=attention_mask) _, preds = torch.max(outputs.logits, dim=1) - correct_predictions += torch.sum(preds == labels) - total_predictions += labels.shape[0] + predictions.extend(preds.cpu().tolist()) + true_labels.extend(labels.cpu().tolist()) - return correct_predictions.float() / total_predictions + accuracy = sum(1 for p, t in zip(predictions, true_labels) if p == t) / len(true_labels) + report = classification_report(true_labels, predictions, target_names=['Fake', 'Real']) + + print('Confusion Matrix:') + print(confusion_matrix(true_labels, predictions)) + + return accuracy, report def save_model(self, path): self.model.save_pretrained(path) self.tokenizer.save_pretrained(path) +class FakeNewsInference: + def __init__(self, model_path): + self.tokenizer = BertTokenizer.from_pretrained(model_path) + self.model = BertForSequenceClassification.from_pretrained(model_path) + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.model.to(self.device) + self.model.eval() + def predict(self, text): + inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self.model(**inputs) + probabilities = torch.softmax(outputs.logits, dim=1) + prediction = torch.argmax(probabilities, dim=1).item() + + return prediction, probabilities[0][prediction].item() + +# Usage example if __name__ == '__main__': + # Load and preprocess the data df = pq.read_table('dataset.parquet').to_pandas() # Split the data - train_df, val_df = train_test_split(df, test_size=0.35, random_state=42) + train_df, val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label']) # Initialize and train the model - trainer = FakeNewsModelTrainer(size_factor=0.5) + trainer = FakeNewsModelTrainer() train_data = trainer.prepare_data(train_df) - val_data = trainer.prepare_data(val_df) + val_data = trainer.prepare_data(val_df)[0] trainer.train(train_data, val_data) # Save the model - trainer.save_model('VeriMindSmall') \ No newline at end of file + trainer.save_model('VeriMind') + + # Inference example + inference = FakeNewsInference('fake_news_detector_model') + sample_texts = [ + "Breaking news: Scientists discover new planet in solar system", + "Celebrity secretly lizard person, unnamed sources claim", + "New study shows benefits of regular exercise", + "Government admits to hiding alien life, whistleblower reveals" + ] + for text in sample_texts: + prediction, confidence = inference.predict(text) + print(f"Text: {text}") + print(f"Prediction: {'Real' if prediction == 1 else 'Fake'}") + print(f"Confidence: {confidence:.4f}\n") \ No newline at end of file