Merge pull request 'bugfix' (#14) from feat/fix_saving into develop
Reviewed-on: #14
This commit is contained in:
commit
ee58b085f5
|
@ -7,7 +7,7 @@ config = AIIAConfig(model_name="AIIA-Base-512x20k")
|
||||||
model = AIIABase(config)
|
model = AIIABase(config)
|
||||||
|
|
||||||
# Initialize pretrainer with the model
|
# Initialize pretrainer with the model
|
||||||
pretrainer = Pretrainer(model, learning_rate=1e-4)
|
pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)
|
||||||
|
|
||||||
# List of dataset paths
|
# List of dataset paths
|
||||||
dataset_paths = [
|
dataset_paths = [
|
||||||
|
|
|
@ -2,4 +2,7 @@ torch>=2.5.0
|
||||||
numpy
|
numpy
|
||||||
tqdm
|
tqdm
|
||||||
pytest
|
pytest
|
||||||
pillow
|
pillow
|
||||||
|
pandas
|
||||||
|
torchvision
|
||||||
|
pyarrow
|
|
@ -130,7 +130,6 @@ class AIIABaseShared(AIIA):
|
||||||
self.max_pool = nn.MaxPool2d(
|
self.max_pool = nn.MaxPool2d(
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
Loading…
Reference in New Issue