test_model #13
|
@ -21,9 +21,9 @@ class FakeNewsModelTrainer:
|
||||||
original_config = BertConfig.from_pretrained(model_name)
|
original_config = BertConfig.from_pretrained(model_name)
|
||||||
|
|
||||||
# Calculate new dimensions
|
# Calculate new dimensions
|
||||||
new_hidden_size = max(int(original_config.hidden_size * size_factor ** 0.5), 16)
|
|
||||||
new_num_hidden_layers = max(int(original_config.num_hidden_layers * size_factor ** 0.5), 1)
|
|
||||||
new_num_attention_heads = max(int(original_config.num_attention_heads * size_factor ** 0.5), 1)
|
new_num_attention_heads = max(int(original_config.num_attention_heads * size_factor ** 0.5), 1)
|
||||||
|
new_hidden_size = new_num_attention_heads * max(int((original_config.hidden_size // original_config.num_attention_heads) * size_factor ** 0.5), 1)
|
||||||
|
new_num_hidden_layers = max(int(original_config.num_hidden_layers * size_factor ** 0.5), 1)
|
||||||
|
|
||||||
# Create a new config with reduced size
|
# Create a new config with reduced size
|
||||||
config = BertConfig(
|
config = BertConfig(
|
||||||
|
|
Reference in New Issue