test_model #13

Merged
Fabel merged 21 commits from test_model into develop 2024-09-03 08:53:54 +00:00
1 changed files with 51 additions and 0 deletions
Showing only changes of commit 24b6a2a8a7 - Show all commits

51
src/model/translate.py Normal file
View File

@ -0,0 +1,51 @@
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm
# Load the CSV file
df = pd.read_csv('/root/schule/WELFake_Dataset.csv')
# Take a 10% sample
sample_size = int(len(df) * 0.1)
df_sample = df.sample(n=sample_size, random_state=42)
# Load the translation model
model_name = "Helsinki-NLP/opus-mt-en-de"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Function to translate text
def translate(text):
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
translated = model.generate(**inputs)
return tokenizer.decode(translated[0], skip_special_tokens=True)
# Translate 'text' and 'title' columns
tqdm.pandas()
df_sample['title_de'] = df_sample['title'].progress_apply(translate)
df_sample['text_de'] = df_sample['text'].progress_apply(translate)
# Calculate the new serial numbers
max_serial = df['Serial'].max()
df_sample['Serial_de'] = df_sample['Serial'].apply(lambda x: x + max_serial + 1)
# Create new rows with translated content
df_translated = df_sample.copy()
df_translated['Serial'] = df_translated['Serial_de']
df_translated['title'] = df_translated['title_de']
df_translated['text'] = df_translated['text_de']
# Drop the temporary columns
df_translated = df_translated.drop(['Serial_de', 'title_de', 'text_de'], axis=1)
# Combine original and translated DataFrames
df_combined = pd.concat([df, df_translated], ignore_index=True)
# Sort by Serial number
df_combined = df_combined.sort_values('Serial').reset_index(drop=True)
# Save as parquet
df_combined.to_parquet('combined_with_translations.parquet', index=False)
print("Translation, combination, and saving completed.")