From 55a50276fa730185a643fa0403e91c4b46b42de9 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Fri, 30 Aug 2024 08:02:51 +0200 Subject: [PATCH] better integration --- src/model/translate.py | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/model/translate.py b/src/model/translate.py index e861ff1..c1da695 100644 --- a/src/model/translate.py +++ b/src/model/translate.py @@ -4,10 +4,22 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM from tqdm import tqdm # Load the CSV file -df = pd.read_csv('/root/schule/WELFake_Dataset.csv') +file_path = '/root/schule/WELFake_Dataset.csv' +try: + df = pd.read_csv(file_path) +except FileNotFoundError: + print(f"File not found: {file_path}") + exit(1) -# Take a 10% sample -sample_size = int(len(df) * 0.1) +print("Columns in the DataFrame:", df.columns) + +# Ensure the '#' column exists +if '#' not in df.columns: + print("'#' column not found. Please check your CSV file.") + exit(1) + +# Take a sample of 10 entries +sample_size = 10 df_sample = df.sample(n=sample_size, random_state=42) # Load the translation model @@ -18,38 +30,36 @@ model = AutoModelForSeq2SeqLM.from_pretrained(model_name) # Function to translate text def translate(text): if pd.isna(text) or text == '': - return '' # Return an empty string for NaN or empty string inputs + return '' inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) translated = model.generate(**inputs) return tokenizer.decode(translated[0], skip_special_tokens=True) - # Translate 'text' and 'title' columns tqdm.pandas() df_sample['title_de'] = df_sample['title'].fillna('').progress_apply(translate) df_sample['text_de'] = df_sample['text'].fillna('').progress_apply(translate) - # Calculate the new serial numbers -max_serial = df['Serial'].max() -df_sample['Serial_de'] = df_sample['Serial'].apply(lambda x: x + max_serial + 1) +max_serial = df['#'].max() +df_sample['#_de'] = df_sample['#'].apply(lambda x: x + max_serial + 1) # Create new rows with translated content df_translated = df_sample.copy() -df_translated['Serial'] = df_translated['Serial_de'] +df_translated['#'] = df_translated['#_de'] df_translated['title'] = df_translated['title_de'] df_translated['text'] = df_translated['text_de'] # Drop the temporary columns -df_translated = df_translated.drop(['Serial_de', 'title_de', 'text_de'], axis=1) +df_translated = df_translated.drop(['#_de', 'title_de', 'text_de'], axis=1) # Combine original and translated DataFrames df_combined = pd.concat([df, df_translated], ignore_index=True) -# Sort by Serial number -df_combined = df_combined.sort_values('Serial').reset_index(drop=True) +# Sort by '#' (serial) number +df_combined = df_combined.sort_values('#').reset_index(drop=True) # Save as parquet -df_combined.to_parquet('combined_with_translations.parquet', index=False) +df_combined.to_parquet('combined_with_translations_10_samples.parquet', index=False) -print("Translation, combination, and saving completed.") +print("Translation, combination, and saving completed.") \ No newline at end of file