Added validation of training data
This commit is contained in:
parent
aa6c343c40
commit
7b07ad91a1
|
@ -1,5 +1,3 @@
|
|||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.model_selection import train_test_split
|
||||
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
@ -17,13 +15,25 @@ class FakeNewsModelTrainer:
|
|||
self.model.to(self.device)
|
||||
|
||||
def prepare_data(self, df):
    """Build a tokenized ``TensorDataset`` from a DataFrame of articles.

    Each row's ``title`` and ``text`` are joined with a single space and
    stripped. Rows whose combined text is empty are dropped together with
    their label before tokenization.

    Args:
        df: DataFrame providing ``title``, ``text`` and ``label`` columns.

    Returns:
        ``TensorDataset(input_ids, attention_mask, labels)`` covering only
        the rows that have non-empty combined text.

    Raises:
        ValueError: If no row has any text left after filtering.
    """
    # Imported locally so the method does not rely on a module-level
    # pandas import being present.
    import pandas as pd

    def _as_text(value):
        # pandas represents missing cells as NaN/NA, which are truthy:
        # ``value or ''`` would keep them and the f-string below would then
        # render the literal "nan". Map every missing or falsy cell to ''.
        if pd.isna(value) or not value:
            return ''
        return value

    # Combine title and text, treating missing values as empty strings.
    texts = [
        f"{_as_text(title)} {_as_text(body)}".strip()
        for title, body in zip(df['title'], df['text'])
    ]
    labels = df['label'].tolist()

    # Keep text and label paired while dropping empty entries. The texts are
    # already stripped, so plain truthiness is the "not empty" test.
    kept = [(text, label) for text, label in zip(texts, labels) if text]
    if not kept:
        raise ValueError("No valid texts found in the dataset after filtering empty entries.")

    valid_texts = [text for text, _ in kept]
    valid_labels = [label for _, label in kept]

    encoded_texts = self.tokenizer(
        valid_texts,
        padding=True,
        truncation=True,
        max_length=self.max_length,
        return_tensors='pt',
    )
    input_ids = encoded_texts['input_ids']
    attention_mask = encoded_texts['attention_mask']
    labels = torch.tensor(valid_labels)

    return TensorDataset(input_ids, attention_mask, labels)
|
||||
|
||||
|
@ -75,15 +85,13 @@ class FakeNewsModelTrainer:
|
|||
self.model.save_pretrained(path)
|
||||
self.tokenizer.save_pretrained(path)
|
||||
|
||||
|
||||
# Usage example
|
||||
if __name__ == '__main__':
|
||||
# Load and preprocess the data
|
||||
df = pq.read_table('dataset.parquet').to_pandas()
|
||||
df['text'] = df['title'] + ' ' + df['text'] # Combine title and text
|
||||
|
||||
# Split the data
|
||||
train_df, val_df = train_test_split(df, test_size=0.3, random_state=42)
|
||||
train_df, val_df = train_test_split(df, test_size=0.35, random_state=42)
|
||||
|
||||
# Initialize and train the model
|
||||
trainer = FakeNewsModelTrainer()
|
||||
|
|
Reference in New Issue