diff --git a/src/model/translate.py b/src/model/translate.py index 4c58aec..e861ff1 100644 --- a/src/model/translate.py +++ b/src/model/translate.py @@ -17,14 +17,18 @@ model = AutoModelForSeq2SeqLM.from_pretrained(model_name) # Function to translate text def translate(text): + if pd.isna(text) or text == '': + return '' # Return an empty string for NaN or empty string inputs inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) translated = model.generate(**inputs) return tokenizer.decode(translated[0], skip_special_tokens=True) + # Translate 'text' and 'title' columns tqdm.pandas() -df_sample['title_de'] = df_sample['title'].progress_apply(translate) -df_sample['text_de'] = df_sample['text'].progress_apply(translate) +df_sample['title_de'] = df_sample['title'].fillna('').progress_apply(translate) +df_sample['text_de'] = df_sample['text'].fillna('').progress_apply(translate) + # Calculate the new serial numbers max_serial = df['Serial'].max()