From 24b6a2a8a7969c76505b55303c8670ddc87d2921 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Thu, 29 Aug 2024 11:11:59 +0200 Subject: [PATCH] translation for model added --- src/model/translate.py | 51 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 src/model/translate.py diff --git a/src/model/translate.py b/src/model/translate.py new file mode 100644 index 0000000..4c58aec --- /dev/null +++ b/src/model/translate.py @@ -0,0 +1,51 @@ +import pandas as pd +import torch +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM +from tqdm import tqdm + +# Load the CSV file +df = pd.read_csv('/root/schule/WELFake_Dataset.csv') + +# Take a 10% sample +sample_size = int(len(df) * 0.1) +df_sample = df.sample(n=sample_size, random_state=42) + +# Load the translation model +model_name = "Helsinki-NLP/opus-mt-en-de" +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + +# Function to translate text +def translate(text): + inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) + translated = model.generate(**inputs) + return tokenizer.decode(translated[0], skip_special_tokens=True) + +# Translate 'text' and 'title' columns +tqdm.pandas() +df_sample['title_de'] = df_sample['title'].progress_apply(translate) +df_sample['text_de'] = df_sample['text'].progress_apply(translate) + +# Calculate the new serial numbers +max_serial = df['Serial'].max() +df_sample['Serial_de'] = df_sample['Serial'].apply(lambda x: x + max_serial + 1) + +# Create new rows with translated content +df_translated = df_sample.copy() +df_translated['Serial'] = df_translated['Serial_de'] +df_translated['title'] = df_translated['title_de'] +df_translated['text'] = df_translated['text_de'] + +# Drop the temporary columns +df_translated = df_translated.drop(['Serial_de', 'title_de', 'text_de'], axis=1) + +# Combine original and translated DataFrames +df_combined = pd.concat([df, df_translated], ignore_index=True) + +# Sort by Serial number +df_combined = df_combined.sort_values('Serial').reset_index(drop=True) + +# Save as parquet +df_combined.to_parquet('combined_with_translations.parquet', index=False) + +print("Translation, combination, and saving completed.")