translation for model added
This commit is contained in:
parent
37d60e436f
commit
24b6a2a8a7
|
@ -0,0 +1,51 @@
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Load the CSV file
|
||||||
|
df = pd.read_csv('/root/schule/WELFake_Dataset.csv')
|
||||||
|
|
||||||
|
# Take a 10% sample
|
||||||
|
sample_size = int(len(df) * 0.1)
|
||||||
|
df_sample = df.sample(n=sample_size, random_state=42)
|
||||||
|
|
||||||
|
# Load the translation model
|
||||||
|
model_name = "Helsinki-NLP/opus-mt-en-de"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
||||||
|
|
||||||
|
# Function to translate text
|
||||||
|
def translate(text):
|
||||||
|
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
||||||
|
translated = model.generate(**inputs)
|
||||||
|
return tokenizer.decode(translated[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Translate 'text' and 'title' columns
|
||||||
|
tqdm.pandas()
|
||||||
|
df_sample['title_de'] = df_sample['title'].progress_apply(translate)
|
||||||
|
df_sample['text_de'] = df_sample['text'].progress_apply(translate)
|
||||||
|
|
||||||
|
# Calculate the new serial numbers
|
||||||
|
max_serial = df['Serial'].max()
|
||||||
|
df_sample['Serial_de'] = df_sample['Serial'].apply(lambda x: x + max_serial + 1)
|
||||||
|
|
||||||
|
# Create new rows with translated content
|
||||||
|
df_translated = df_sample.copy()
|
||||||
|
df_translated['Serial'] = df_translated['Serial_de']
|
||||||
|
df_translated['title'] = df_translated['title_de']
|
||||||
|
df_translated['text'] = df_translated['text_de']
|
||||||
|
|
||||||
|
# Drop the temporary columns
|
||||||
|
df_translated = df_translated.drop(['Serial_de', 'title_de', 'text_de'], axis=1)
|
||||||
|
|
||||||
|
# Combine original and translated DataFrames
|
||||||
|
df_combined = pd.concat([df, df_translated], ignore_index=True)
|
||||||
|
|
||||||
|
# Sort by Serial number
|
||||||
|
df_combined = df_combined.sort_values('Serial').reset_index(drop=True)
|
||||||
|
|
||||||
|
# Save as parquet
|
||||||
|
df_combined.to_parquet('combined_with_translations.parquet', index=False)
|
||||||
|
|
||||||
|
print("Translation, combination, and saving completed.")
|
Reference in New Issue