Corrected BertConfig initialization

This commit is contained in:
Falko Victor Habel 2024-09-01 11:06:08 +02:00
parent 306cd5619d
commit 5d8ec1a01f
1 changed file with 2 additions and 2 deletions

View File

@@ -21,9 +21,9 @@ class FakeNewsModelTrainer:
original_config = BertConfig.from_pretrained(model_name)
# Calculate new dimensions
new_hidden_size = max(int(original_config.hidden_size * size_factor ** 0.5), 16)
new_num_hidden_layers = max(int(original_config.num_hidden_layers * size_factor ** 0.5), 1)
new_num_attention_heads = max(int(original_config.num_attention_heads * size_factor ** 0.5), 1)
new_hidden_size = new_num_attention_heads * max(int((original_config.hidden_size // original_config.num_attention_heads) * size_factor ** 0.5), 1)
new_num_hidden_layers = max(int(original_config.num_hidden_layers * size_factor ** 0.5), 1)
# Create a new config with reduced size
config = BertConfig(