added validation of trainingsdata
This commit is contained in:
parent
aa6c343c40
commit
7b07ad91a1
|
@ -1,5 +1,3 @@
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
|
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
@ -17,13 +15,25 @@ class FakeNewsModelTrainer:
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
def prepare_data(self, df):
|
def prepare_data(self, df):
|
||||||
texts = df['text'].tolist()
|
# Combine title and text, handling potential empty values
|
||||||
|
texts = df.apply(lambda row: f"{row['title'] or ''} {row['text'] or ''}".strip(), axis=1).tolist()
|
||||||
labels = df['label'].tolist()
|
labels = df['label'].tolist()
|
||||||
|
|
||||||
encoded_texts = self.tokenizer(texts, padding=True, truncation=True, max_length=self.max_length, return_tensors='pt')
|
# Filter out empty texts
|
||||||
|
valid_texts = []
|
||||||
|
valid_labels = []
|
||||||
|
for text, label in zip(texts, labels):
|
||||||
|
if text.strip(): # Check if the text is not empty after stripping whitespace
|
||||||
|
valid_texts.append(text)
|
||||||
|
valid_labels.append(label)
|
||||||
|
|
||||||
|
if not valid_texts:
|
||||||
|
raise ValueError("No valid texts found in the dataset after filtering empty entries.")
|
||||||
|
|
||||||
|
encoded_texts = self.tokenizer(valid_texts, padding=True, truncation=True, max_length=self.max_length, return_tensors='pt')
|
||||||
input_ids = encoded_texts['input_ids']
|
input_ids = encoded_texts['input_ids']
|
||||||
attention_mask = encoded_texts['attention_mask']
|
attention_mask = encoded_texts['attention_mask']
|
||||||
labels = torch.tensor(labels)
|
labels = torch.tensor(valid_labels)
|
||||||
|
|
||||||
return TensorDataset(input_ids, attention_mask, labels)
|
return TensorDataset(input_ids, attention_mask, labels)
|
||||||
|
|
||||||
|
@ -75,15 +85,13 @@ class FakeNewsModelTrainer:
|
||||||
self.model.save_pretrained(path)
|
self.model.save_pretrained(path)
|
||||||
self.tokenizer.save_pretrained(path)
|
self.tokenizer.save_pretrained(path)
|
||||||
|
|
||||||
|
|
||||||
# Usage example
|
# Usage example
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Load and preprocess the data
|
# Load and preprocess the data
|
||||||
df = pq.read_table('dataset.parquet').to_pandas()
|
df = pq.read_table('dataset.parquet').to_pandas()
|
||||||
df['text'] = df['title'] + ' ' + df['text'] # Combine title and text
|
|
||||||
|
|
||||||
# Split the data
|
# Split the data
|
||||||
train_df, val_df = train_test_split(df, test_size=0.3, random_state=42)
|
train_df, val_df = train_test_split(df, test_size=0.35, random_state=42)
|
||||||
|
|
||||||
# Initialize and train the model
|
# Initialize and train the model
|
||||||
trainer = FakeNewsModelTrainer()
|
trainer = FakeNewsModelTrainer()
|
||||||
|
|
Reference in New Issue