diff --git a/src/model/train.py b/src/model/train.py
index 14f00dc..97fa4da 100644
--- a/src/model/train.py
+++ b/src/model/train.py
@@ -1,5 +1,3 @@
-import pandas as pd
-import numpy as np
 from sklearn.model_selection import train_test_split
 from transformers import BertTokenizer, BertForSequenceClassification, AdamW
 from torch.utils.data import DataLoader, TensorDataset
@@ -17,13 +15,25 @@ class FakeNewsModelTrainer:
         self.model.to(self.device)
 
     def prepare_data(self, df):
-        texts = df['text'].tolist()
+        # Combine title and text, handling potential empty values
+        texts = df.apply(lambda row: f"{row['title'] or ''} {row['text'] or ''}".strip(), axis=1).tolist()
         labels = df['label'].tolist()
-        encoded_texts = self.tokenizer(texts, padding=True, truncation=True, max_length=self.max_length, return_tensors='pt')
+        # Filter out empty texts
+        valid_texts = []
+        valid_labels = []
+        for text, label in zip(texts, labels):
+            if text.strip():  # Check if the text is not empty after stripping whitespace
+                valid_texts.append(text)
+                valid_labels.append(label)
+
+        if not valid_texts:
+            raise ValueError("No valid texts found in the dataset after filtering empty entries.")
+
+        encoded_texts = self.tokenizer(valid_texts, padding=True, truncation=True, max_length=self.max_length, return_tensors='pt')
         input_ids = encoded_texts['input_ids']
         attention_mask = encoded_texts['attention_mask']
-        labels = torch.tensor(labels)
+        labels = torch.tensor(valid_labels)
 
         return TensorDataset(input_ids, attention_mask, labels)
 
@@ -75,15 +85,13 @@ class FakeNewsModelTrainer:
         self.model.save_pretrained(path)
         self.tokenizer.save_pretrained(path)
 
-# Usage example
 if __name__ == '__main__':
     # Load and preprocess the data
     df = pq.read_table('dataset.parquet').to_pandas()
-    df['text'] = df['title'] + ' ' + df['text']  # Combine title and text
 
     # Split the data
-    train_df, val_df = train_test_split(df, test_size=0.3, random_state=42)
+    train_df, val_df = train_test_split(df, test_size=0.35, random_state=42)
 
     # Initialize and train the model
     trainer = FakeNewsModelTrainer()