added validation of trainingsdata

This commit is contained in:
Falko Victor Habel 2024-08-31 08:19:47 +02:00
parent aa6c343c40
commit 7b07ad91a1
1 changed files with 16 additions and 8 deletions

View File

@ -1,5 +1,3 @@
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, TensorDataset
@ -17,13 +15,25 @@ class FakeNewsModelTrainer:
self.model.to(self.device)
def prepare_data(self, df):
texts = df['text'].tolist()
# Combine title and text, handling potential empty values
texts = df.apply(lambda row: f"{row['title'] or ''} {row['text'] or ''}".strip(), axis=1).tolist()
labels = df['label'].tolist()
encoded_texts = self.tokenizer(texts, padding=True, truncation=True, max_length=self.max_length, return_tensors='pt')
# Filter out empty texts
valid_texts = []
valid_labels = []
for text, label in zip(texts, labels):
if text.strip(): # Check if the text is not empty after stripping whitespace
valid_texts.append(text)
valid_labels.append(label)
if not valid_texts:
raise ValueError("No valid texts found in the dataset after filtering empty entries.")
encoded_texts = self.tokenizer(valid_texts, padding=True, truncation=True, max_length=self.max_length, return_tensors='pt')
input_ids = encoded_texts['input_ids']
attention_mask = encoded_texts['attention_mask']
labels = torch.tensor(labels)
labels = torch.tensor(valid_labels)
return TensorDataset(input_ids, attention_mask, labels)
@ -75,15 +85,13 @@ class FakeNewsModelTrainer:
self.model.save_pretrained(path)
self.tokenizer.save_pretrained(path)
# Usage example
if __name__ == '__main__':
# Load and preprocess the data
df = pq.read_table('dataset.parquet').to_pandas()
df['text'] = df['title'] + ' ' + df['text'] # Combine title and text
# Split the data
train_df, val_df = train_test_split(df, test_size=0.3, random_state=42)
train_df, val_df = train_test_split(df, test_size=0.35, random_state=42)
# Initialize and train the model
trainer = FakeNewsModelTrainer()