test_model #13

Merged
Fabel merged 21 commits from test_model into develop 2024-09-03 08:53:54 +00:00
1 changed files with 24 additions and 14 deletions
Showing only changes of commit 55a50276fa - Show all commits

View File

@ -4,10 +4,22 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm
# Load the CSV file
df = pd.read_csv('/root/schule/WELFake_Dataset.csv')
file_path = '/root/schule/WELFake_Dataset.csv'
try:
df = pd.read_csv(file_path)
except FileNotFoundError:
print(f"File not found: {file_path}")
exit(1)
# Take a 10% sample
sample_size = int(len(df) * 0.1)
print("Columns in the DataFrame:", df.columns)
# Ensure the '#' column exists
if '#' not in df.columns:
print("'#' column not found. Please check your CSV file.")
exit(1)
# Take a sample of 10 entries
sample_size = 10
df_sample = df.sample(n=sample_size, random_state=42)
# Load the translation model
@ -18,38 +30,36 @@ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Function to translate text
def translate(text):
if pd.isna(text) or text == '':
return '' # Return an empty string for NaN or empty string inputs
return ''
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
translated = model.generate(**inputs)
return tokenizer.decode(translated[0], skip_special_tokens=True)
# Translate 'text' and 'title' columns
tqdm.pandas()
df_sample['title_de'] = df_sample['title'].fillna('').progress_apply(translate)
df_sample['text_de'] = df_sample['text'].fillna('').progress_apply(translate)
# Calculate the new serial numbers
max_serial = df['Serial'].max()
df_sample['Serial_de'] = df_sample['Serial'].apply(lambda x: x + max_serial + 1)
max_serial = df['#'].max()
df_sample['#_de'] = df_sample['#'].apply(lambda x: x + max_serial + 1)
# Create new rows with translated content
df_translated = df_sample.copy()
df_translated['Serial'] = df_translated['Serial_de']
df_translated['#'] = df_translated['#_de']
df_translated['title'] = df_translated['title_de']
df_translated['text'] = df_translated['text_de']
# Drop the temporary columns
df_translated = df_translated.drop(['Serial_de', 'title_de', 'text_de'], axis=1)
df_translated = df_translated.drop(['#_de', 'title_de', 'text_de'], axis=1)
# Combine original and translated DataFrames
df_combined = pd.concat([df, df_translated], ignore_index=True)
# Sort by Serial number
df_combined = df_combined.sort_values('Serial').reset_index(drop=True)
# Sort by '#' (serial) number
df_combined = df_combined.sort_values('#').reset_index(drop=True)
# Save as parquet
df_combined.to_parquet('combined_with_translations.parquet', index=False)
df_combined.to_parquet('combined_with_translations_10_samples.parquet', index=False)
print("Translation, combination, and saving completed.")
print("Translation, combination, and saving completed.")