training class added
This commit is contained in:
parent
9592ab8140
commit
cbfcad6088
src/model
|
@ -0,0 +1,95 @@
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
|
||||||
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
|
||||||
|
class FakeNewsModelTrainer:
|
||||||
|
def __init__(self, model_name='google-bert/bert-base-multilingual-cased', max_length=512):
|
||||||
|
self.model_name = model_name
|
||||||
|
self.max_length = max_length
|
||||||
|
self.tokenizer = BertTokenizer.from_pretrained(model_name)
|
||||||
|
self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
def prepare_data(self, df):
|
||||||
|
texts = df['text'].tolist()
|
||||||
|
labels = df['label'].tolist()
|
||||||
|
|
||||||
|
encoded_texts = self.tokenizer(texts, padding=True, truncation=True, max_length=self.max_length, return_tensors='pt')
|
||||||
|
input_ids = encoded_texts['input_ids']
|
||||||
|
attention_mask = encoded_texts['attention_mask']
|
||||||
|
labels = torch.tensor(labels)
|
||||||
|
|
||||||
|
return TensorDataset(input_ids, attention_mask, labels)
|
||||||
|
|
||||||
|
def train(self, train_data, val_data, epochs=3, batch_size=16):
|
||||||
|
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
|
||||||
|
val_dataloader = DataLoader(val_data, batch_size=batch_size)
|
||||||
|
|
||||||
|
optimizer = AdamW(self.model.parameters(), lr=2e-5)
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
self.model.train()
|
||||||
|
total_loss = 0
|
||||||
|
|
||||||
|
for batch in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{epochs}'):
|
||||||
|
input_ids, attention_mask, labels = [b.to(self.device) for b in batch]
|
||||||
|
|
||||||
|
outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
|
||||||
|
loss = outputs.loss
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
avg_train_loss = total_loss / len(train_dataloader)
|
||||||
|
print(f'Average training loss: {avg_train_loss:.4f}')
|
||||||
|
|
||||||
|
val_accuracy = self.evaluate(val_dataloader)
|
||||||
|
print(f'Validation accuracy: {val_accuracy:.4f}')
|
||||||
|
|
||||||
|
def evaluate(self, dataloader):
|
||||||
|
self.model.eval()
|
||||||
|
correct_predictions = 0
|
||||||
|
total_predictions = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in dataloader:
|
||||||
|
input_ids, attention_mask, labels = [b.to(self.device) for b in batch]
|
||||||
|
|
||||||
|
outputs = self.model(input_ids, attention_mask=attention_mask)
|
||||||
|
_, preds = torch.max(outputs.logits, dim=1)
|
||||||
|
|
||||||
|
correct_predictions += torch.sum(preds == labels)
|
||||||
|
total_predictions += labels.shape[0]
|
||||||
|
|
||||||
|
return correct_predictions.float() / total_predictions
|
||||||
|
|
||||||
|
def save_model(self, path):
|
||||||
|
self.model.save_pretrained(path)
|
||||||
|
self.tokenizer.save_pretrained(path)
|
||||||
|
|
||||||
|
|
||||||
|
# Usage example
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Load and preprocess the data
|
||||||
|
df = pq.read_table('your_dataset.parquet').to_pandas()
|
||||||
|
df['text'] = df['title'] + ' ' + df['text'] # Combine title and text
|
||||||
|
|
||||||
|
# Split the data
|
||||||
|
train_df, val_df = train_test_split(df, test_size=0.3, random_state=42)
|
||||||
|
|
||||||
|
# Initialize and train the model
|
||||||
|
trainer = FakeNewsModelTrainer()
|
||||||
|
train_data = trainer.prepare_data(train_df)
|
||||||
|
val_data = trainer.prepare_data(val_df)
|
||||||
|
trainer.train(train_data, val_data)
|
||||||
|
|
||||||
|
# Save the model
|
||||||
|
trainer.save_model('VeriMind')
|
Reference in New Issue