diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index d3279ba..a0e67b2 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -174,9 +174,9 @@ def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=1 return model def main(): - BATCH_SIZE = 2 + BATCH_SIZE = 1 ACCUMULATION_STEPS = 8 - USE_CHECKPOINT = True + USE_CHECKPOINT = False # Load the base model using the config values (hidden_size=512, num_channels=3, etc.) base_model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")