bugfix
This commit is contained in:
parent
06a2e012e8
commit
66762775e9
|
@ -7,7 +7,7 @@ config = AIIAConfig(model_name="AIIA-Base-512x20k")
|
|||
model = AIIABase(config)
|
||||
|
||||
# Initialize pretrainer with the model
|
||||
pretrainer = Pretrainer(model, learning_rate=1e-4)
|
||||
pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)
|
||||
|
||||
# List of dataset paths
|
||||
dataset_paths = [
|
||||
|
|
|
@ -2,4 +2,7 @@ torch>=2.5.0
|
|||
numpy
|
||||
tqdm
|
||||
pytest
|
||||
pillow
|
||||
pillow
|
||||
pandas
|
||||
torchvision
|
||||
pyarrow
|
|
@ -130,7 +130,6 @@ class AIIABaseShared(AIIA):
|
|||
self.max_pool = nn.MaxPool2d(
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=1
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
Loading…
Reference in New Issue