diff --git a/src/model/translate.py b/src/model/translate.py deleted file mode 100644 index 463189a..0000000 --- a/src/model/translate.py +++ /dev/null @@ -1,53 +0,0 @@ -import pandas as pd -import torch -from transformers import AutoTokenizer, AutoModelForSeq2SeqLM -from tqdm import tqdm - -# Load the CSV file -file_path = '/root/schule/WELFake_Dataset.csv' -try: - df = pd.read_csv(file_path) -except FileNotFoundError: - print(f"File not found: {file_path}") - exit(1) - -print("Columns in the DataFrame:", df.columns) - -# Ensure the 'Unnamed: 0' column exists -if 'Unnamed: 0' not in df.columns: - print("'Unnamed: 0' column not found. Please check your CSV file.") - exit(1) - -# Take a sample of 10,000 entries -sample_size = 10000 -df_sample = df.sample(n=sample_size, random_state=42) - -# Load the translation model -model_name = "Helsinki-NLP/opus-mt-en-de" -tokenizer = AutoTokenizer.from_pretrained(model_name) -model = AutoModelForSeq2SeqLM.from_pretrained(model_name) - -# Function to translate text -def translate(text, max_length=512): - if pd.isna(text) or text == '': - return '' - # Remove special characters and limit length - text = ''.join(char for char in text if char.isalnum() or char.isspace()) - text = text[:max_length] - inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=max_length) - translated = model.generate(**inputs) - return tokenizer.decode(translated[0], skip_special_tokens=True) - -# Translate 'text' and 'title' columns -tqdm.pandas() -df_sample['title_de'] = df_sample['title'].fillna('').progress_apply(translate) -df_sample['text_de'] = df_sample['text'].fillna('').progress_apply(lambda x: translate(x, max_length=1024)) - -# Calculate the new serial numbers -max_serial = df['Unnamed: 0'].max() -df_sample['Unnamed: 0_de'] = df_sample['Unnamed: 0'].apply(lambda x: x + max_serial + 1) - -# Create new rows with translated content -df_translated = df_sample.copy() -df_translated['Unnamed: 0'] = df_translated['Unnamed: 0_de'] -df_translated['title'] = df_translated['title_de