included inference into training and added weighted sampler

This commit is contained in:
Falko Victor Habel 2024-08-31 21:40:20 +02:00
parent cd8e1857ea
commit a824724854
1 changed file with 77 additions and 43 deletions


@@ -1,39 +1,19 @@
 import pandas as pd
 import numpy as np
 from sklearn.model_selection import train_test_split
-from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AdamW
-from torch.utils.data import DataLoader, TensorDataset
+from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
+from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
 import torch
 from tqdm import tqdm
 import pyarrow.parquet as pq
+from sklearn.metrics import classification_report, confusion_matrix

 class FakeNewsModelTrainer:
-    def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512, size_factor=0.5):
+    def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512):
         self.model_name = model_name
         self.max_length = max_length
-        self.size_factor = size_factor
         self.tokenizer = BertTokenizer.from_pretrained(model_name)
-        # Load the original config
-        original_config = BertConfig.from_pretrained(model_name)
-        # Calculate new dimensions
-        new_hidden_size = max(int(original_config.hidden_size * size_factor ** 0.5), 16)
-        new_num_hidden_layers = max(int(original_config.num_hidden_layers * size_factor ** 0.5), 1)
-        new_num_attention_heads = max(int(original_config.num_attention_heads * size_factor ** 0.5), 1)
-        # Create a new config with reduced size
-        config = BertConfig(
-            vocab_size=original_config.vocab_size,
-            hidden_size=new_hidden_size,
-            num_hidden_layers=new_num_hidden_layers,
-            num_attention_heads=new_num_attention_heads,
-            intermediate_size=new_hidden_size * 4,
-            max_position_embeddings=original_config.max_position_embeddings,
-            num_labels=2
-        )
-        # Initialize the model with the new config
-        self.model = BertForSequenceClassification(config)
+        self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.model.to(self.device)
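Note: the hunk above removes the size_factor mechanism, which built a randomly initialised, scaled-down BertConfig, and instead loads the full pretrained multilingual checkpoint with a two-label head. The sketch below (not part of the commit) contrasts the two approaches; the reduced dimensions are illustrative values chosen here, not the old defaults, and hidden_size is kept divisible by num_attention_heads as BERT requires.

from transformers import BertConfig, BertForSequenceClassification

# Illustrative scaled-down config (values chosen for this sketch, not from the commit)
small_config = BertConfig(hidden_size=384, num_hidden_layers=6,
                          num_attention_heads=6, intermediate_size=1536,
                          num_labels=2)
small_model = BertForSequenceClassification(small_config)  # randomly initialised weights

# What the commit switches to: the full pretrained checkpoint with a 2-label head
full_model = BertForSequenceClassification.from_pretrained(
    'google-bert/bert-base-multilingual-cased', num_labels=2)

print('small:', sum(p.numel() for p in small_model.parameters()))
print('full: ', sum(p.numel() for p in full_model.parameters()))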
@@ -56,13 +36,23 @@ class FakeNewsModelTrainer:
         attention_mask = encoded_texts['attention_mask']
         labels = torch.tensor(valid_labels)
-        return TensorDataset(input_ids, attention_mask, labels)
+        # Create a weighted sampler for balanced batches
+        class_sample_count = np.array([len(np.where(valid_labels == t)[0]) for t in np.unique(valid_labels)])
+        weight = 1. / class_sample_count
+        samples_weight = np.array([weight[t] for t in valid_labels])
+        samples_weight = torch.from_numpy(samples_weight)
+        sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight))
+        return TensorDataset(input_ids, attention_mask, labels), sampler

-    def train(self, train_data, val_data, epochs=13, batch_size=64):
-        train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
-        val_dataloader = DataLoader(val_data, batch_size=batch_size)
-        optimizer = AdamW(self.model.parameters(), lr=2e-5)
+    def train(self, train_data, val_data, epochs=5, batch_size=32):
+        train_dataset, train_sampler = train_data
+        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)
+        val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
+        optimizer = AdamW(self.model.parameters(), lr=2e-5, eps=1e-8)
+        total_steps = len(train_dataloader) * epochs
+        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

         for epoch in range(epochs):
             self.model.train()
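Note: prepare_data now returns a WeightedRandomSampler alongside the TensorDataset so minority-class rows are drawn more often. A self-contained sketch (toy tensors, not the real dataset) of the same inverse-frequency recipe, showing how the sampler rebalances batches:

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy imbalanced labels: eight 0s ("Fake") and two 1s ("Real")
labels = np.array([0] * 8 + [1] * 2)
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)

# Inverse class frequency, as in prepare_data above
class_sample_count = np.array([len(np.where(labels == t)[0]) for t in np.unique(labels)])
weight = 1.0 / class_sample_count
samples_weight = torch.from_numpy(np.array([weight[t] for t in labels])).double()

sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
loader = DataLoader(TensorDataset(features, torch.from_numpy(labels)),
                    sampler=sampler, batch_size=5)

for xb, yb in loader:
    # With replacement sampling, each batch now holds roughly as many 1s as 0s
    print(yb.tolist())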
@@ -71,24 +61,28 @@ class FakeNewsModelTrainer:
             for batch in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{epochs}'):
                 input_ids, attention_mask, labels = [b.to(self.device) for b in batch]
+                self.model.zero_grad()
                 outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
                 loss = outputs.loss
                 total_loss += loss.item()
                 loss.backward()
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                 optimizer.step()
-                optimizer.zero_grad()
+                scheduler.step()

             avg_train_loss = total_loss / len(train_dataloader)
             print(f'Average training loss: {avg_train_loss:.4f}')
-            val_accuracy = self.evaluate(val_dataloader)
+            val_accuracy, val_report = self.evaluate(val_dataloader)
             print(f'Validation accuracy: {val_accuracy:.4f}')
+            print('Validation Classification Report:')
+            print(val_report)

     def evaluate(self, dataloader):
         self.model.eval()
-        correct_predictions = 0
-        total_predictions = 0
+        predictions = []
+        true_labels = []
         with torch.no_grad():
             for batch in dataloader:
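Note: the training-loop hunk above adds gradient clipping and a linear learning-rate schedule around the optimizer step. A stripped-down sketch of that step order, using torch.optim.AdamW and a dummy linear model in place of BERT (the transformers scheduler import is the same one the diff uses):

import torch
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(4, 2)  # stand-in for the BERT classifier
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-8)
total_steps = 100
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
                                            num_training_steps=total_steps)

for step in range(total_steps):
    model.zero_grad()                                        # clear gradients before the forward pass
    loss = model(torch.randn(8, 4)).pow(2).mean()            # dummy loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # same max-norm as in the diff
    optimizer.step()
    scheduler.step()                                         # step the LR schedule once per batch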
@@ -97,27 +91,67 @@ class FakeNewsModelTrainer:
                 outputs = self.model(input_ids, attention_mask=attention_mask)
                 _, preds = torch.max(outputs.logits, dim=1)
-                correct_predictions += torch.sum(preds == labels)
-                total_predictions += labels.shape[0]
+                predictions.extend(preds.cpu().tolist())
+                true_labels.extend(labels.cpu().tolist())

-        return correct_predictions.float() / total_predictions
+        accuracy = sum(1 for p, t in zip(predictions, true_labels) if p == t) / len(true_labels)
+        report = classification_report(true_labels, predictions, target_names=['Fake', 'Real'])
+        print('Confusion Matrix:')
+        print(confusion_matrix(true_labels, predictions))
+        return accuracy, report

     def save_model(self, path):
         self.model.save_pretrained(path)
         self.tokenizer.save_pretrained(path)

+class FakeNewsInference:
+    def __init__(self, model_path):
+        self.tokenizer = BertTokenizer.from_pretrained(model_path)
+        self.model = BertForSequenceClassification.from_pretrained(model_path)
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.model.to(self.device)
+        self.model.eval()
+
+    def predict(self, text):
+        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        probabilities = torch.softmax(outputs.logits, dim=1)
+        prediction = torch.argmax(probabilities, dim=1).item()
+        return prediction, probabilities[0][prediction].item()

 # Usage example
 if __name__ == '__main__':
     # Load and preprocess the data
     df = pq.read_table('dataset.parquet').to_pandas()
     # Split the data
-    train_df, val_df = train_test_split(df, test_size=0.35, random_state=42)
+    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])
     # Initialize and train the model
-    trainer = FakeNewsModelTrainer(size_factor=0.5)
+    trainer = FakeNewsModelTrainer()
     train_data = trainer.prepare_data(train_df)
-    val_data = trainer.prepare_data(val_df)
+    val_data = trainer.prepare_data(val_df)[0]
     trainer.train(train_data, val_data)
     # Save the model
-    trainer.save_model('VeriMindSmall')
+    trainer.save_model('VeriMind')
+    # Inference example
+    inference = FakeNewsInference('fake_news_detector_model')
+    sample_texts = [
+        "Breaking news: Scientists discover new planet in solar system",
+        "Celebrity secretly lizard person, unnamed sources claim",
+        "New study shows benefits of regular exercise",
+        "Government admits to hiding alien life, whistleblower reveals"
+    ]
+    for text in sample_texts:
+        prediction, confidence = inference.predict(text)
+        print(f"Text: {text}")
+        print(f"Prediction: {'Real' if prediction == 1 else 'Fake'}")
+        print(f"Confidence: {confidence:.4f}\n")