inclucded inference into training and added weighted sampler
This commit is contained in:
parent
cd8e1857ea
commit
a824724854
|
@ -1,39 +1,19 @@
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, AdamW
|
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
|
from sklearn.metrics import classification_report, confusion_matrix
|
||||||
|
|
||||||
class FakeNewsModelTrainer:
|
class FakeNewsModelTrainer:
|
||||||
def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512, size_factor=0.5):
|
def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.size_factor = size_factor
|
|
||||||
self.tokenizer = BertTokenizer.from_pretrained(model_name)
|
self.tokenizer = BertTokenizer.from_pretrained(model_name)
|
||||||
|
self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
|
||||||
# Load the original config
|
|
||||||
original_config = BertConfig.from_pretrained(model_name)
|
|
||||||
|
|
||||||
# Calculate new dimensions
|
|
||||||
new_hidden_size = max(int(original_config.hidden_size * size_factor ** 0.5), 16)
|
|
||||||
new_num_hidden_layers = max(int(original_config.num_hidden_layers * size_factor ** 0.5), 1)
|
|
||||||
new_num_attention_heads = max(int(original_config.num_attention_heads * size_factor ** 0.5), 1)
|
|
||||||
|
|
||||||
# Create a new config with reduced size
|
|
||||||
config = BertConfig(
|
|
||||||
vocab_size=original_config.vocab_size,
|
|
||||||
hidden_size=new_hidden_size,
|
|
||||||
num_hidden_layers=new_num_hidden_layers,
|
|
||||||
num_attention_heads=new_num_attention_heads,
|
|
||||||
intermediate_size=new_hidden_size * 4,
|
|
||||||
max_position_embeddings=original_config.max_position_embeddings,
|
|
||||||
num_labels=2
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize the model with the new config
|
|
||||||
self.model = BertForSequenceClassification(config)
|
|
||||||
|
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
@ -56,13 +36,23 @@ class FakeNewsModelTrainer:
|
||||||
attention_mask = encoded_texts['attention_mask']
|
attention_mask = encoded_texts['attention_mask']
|
||||||
labels = torch.tensor(valid_labels)
|
labels = torch.tensor(valid_labels)
|
||||||
|
|
||||||
return TensorDataset(input_ids, attention_mask, labels)
|
# Create a weighted sampler for balanced batches
|
||||||
|
class_sample_count = np.array([len(np.where(valid_labels == t)[0]) for t in np.unique(valid_labels)])
|
||||||
|
weight = 1. / class_sample_count
|
||||||
|
samples_weight = np.array([weight[t] for t in valid_labels])
|
||||||
|
samples_weight = torch.from_numpy(samples_weight)
|
||||||
|
sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight))
|
||||||
|
|
||||||
def train(self, train_data, val_data, epochs=13, batch_size=64):
|
return TensorDataset(input_ids, attention_mask, labels), sampler
|
||||||
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
|
|
||||||
val_dataloader = DataLoader(val_data, batch_size=batch_size)
|
|
||||||
|
|
||||||
optimizer = AdamW(self.model.parameters(), lr=2e-5)
|
def train(self, train_data, val_data, epochs=5, batch_size=32):
|
||||||
|
train_dataset, train_sampler = train_data
|
||||||
|
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)
|
||||||
|
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
|
||||||
|
|
||||||
|
optimizer = AdamW(self.model.parameters(), lr=2e-5, eps=1e-8)
|
||||||
|
total_steps = len(train_dataloader) * epochs
|
||||||
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
self.model.train()
|
self.model.train()
|
||||||
|
@ -71,24 +61,28 @@ class FakeNewsModelTrainer:
|
||||||
for batch in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{epochs}'):
|
for batch in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{epochs}'):
|
||||||
input_ids, attention_mask, labels = [b.to(self.device) for b in batch]
|
input_ids, attention_mask, labels = [b.to(self.device) for b in batch]
|
||||||
|
|
||||||
|
self.model.zero_grad()
|
||||||
outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
|
outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
|
||||||
loss = outputs.loss
|
loss = outputs.loss
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
scheduler.step()
|
||||||
|
|
||||||
avg_train_loss = total_loss / len(train_dataloader)
|
avg_train_loss = total_loss / len(train_dataloader)
|
||||||
print(f'Average training loss: {avg_train_loss:.4f}')
|
print(f'Average training loss: {avg_train_loss:.4f}')
|
||||||
|
|
||||||
val_accuracy = self.evaluate(val_dataloader)
|
val_accuracy, val_report = self.evaluate(val_dataloader)
|
||||||
print(f'Validation accuracy: {val_accuracy:.4f}')
|
print(f'Validation accuracy: {val_accuracy:.4f}')
|
||||||
|
print('Validation Classification Report:')
|
||||||
|
print(val_report)
|
||||||
|
|
||||||
def evaluate(self, dataloader):
|
def evaluate(self, dataloader):
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
correct_predictions = 0
|
predictions = []
|
||||||
total_predictions = 0
|
true_labels = []
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
|
@ -97,27 +91,67 @@ class FakeNewsModelTrainer:
|
||||||
outputs = self.model(input_ids, attention_mask=attention_mask)
|
outputs = self.model(input_ids, attention_mask=attention_mask)
|
||||||
_, preds = torch.max(outputs.logits, dim=1)
|
_, preds = torch.max(outputs.logits, dim=1)
|
||||||
|
|
||||||
correct_predictions += torch.sum(preds == labels)
|
predictions.extend(preds.cpu().tolist())
|
||||||
total_predictions += labels.shape[0]
|
true_labels.extend(labels.cpu().tolist())
|
||||||
|
|
||||||
return correct_predictions.float() / total_predictions
|
accuracy = sum(1 for p, t in zip(predictions, true_labels) if p == t) / len(true_labels)
|
||||||
|
report = classification_report(true_labels, predictions, target_names=['Fake', 'Real'])
|
||||||
|
|
||||||
|
print('Confusion Matrix:')
|
||||||
|
print(confusion_matrix(true_labels, predictions))
|
||||||
|
|
||||||
|
return accuracy, report
|
||||||
|
|
||||||
def save_model(self, path):
|
def save_model(self, path):
|
||||||
self.model.save_pretrained(path)
|
self.model.save_pretrained(path)
|
||||||
self.tokenizer.save_pretrained(path)
|
self.tokenizer.save_pretrained(path)
|
||||||
|
|
||||||
|
class FakeNewsInference:
|
||||||
|
def __init__(self, model_path):
|
||||||
|
self.tokenizer = BertTokenizer.from_pretrained(model_path)
|
||||||
|
self.model = BertForSequenceClassification.from_pretrained(model_path)
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
def predict(self, text):
|
||||||
|
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
probabilities = torch.softmax(outputs.logits, dim=1)
|
||||||
|
prediction = torch.argmax(probabilities, dim=1).item()
|
||||||
|
|
||||||
|
return prediction, probabilities[0][prediction].item()
|
||||||
|
|
||||||
|
# Usage example
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
# Load and preprocess the data
|
||||||
df = pq.read_table('dataset.parquet').to_pandas()
|
df = pq.read_table('dataset.parquet').to_pandas()
|
||||||
|
|
||||||
# Split the data
|
# Split the data
|
||||||
train_df, val_df = train_test_split(df, test_size=0.35, random_state=42)
|
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])
|
||||||
|
|
||||||
# Initialize and train the model
|
# Initialize and train the model
|
||||||
trainer = FakeNewsModelTrainer(size_factor=0.5)
|
trainer = FakeNewsModelTrainer()
|
||||||
train_data = trainer.prepare_data(train_df)
|
train_data = trainer.prepare_data(train_df)
|
||||||
val_data = trainer.prepare_data(val_df)
|
val_data = trainer.prepare_data(val_df)[0]
|
||||||
trainer.train(train_data, val_data)
|
trainer.train(train_data, val_data)
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
trainer.save_model('VeriMindSmall')
|
trainer.save_model('VeriMind')
|
||||||
|
|
||||||
|
# Inference example
|
||||||
|
inference = FakeNewsInference('fake_news_detector_model')
|
||||||
|
sample_texts = [
|
||||||
|
"Breaking news: Scientists discover new planet in solar system",
|
||||||
|
"Celebrity secretly lizard person, unnamed sources claim",
|
||||||
|
"New study shows benefits of regular exercise",
|
||||||
|
"Government admits to hiding alien life, whistleblower reveals"
|
||||||
|
]
|
||||||
|
for text in sample_texts:
|
||||||
|
prediction, confidence = inference.predict(text)
|
||||||
|
print(f"Text: {text}")
|
||||||
|
print(f"Prediction: {'Real' if prediction == 1 else 'Fake'}")
|
||||||
|
print(f"Confidence: {confidence:.4f}\n")
|
Reference in New Issue