Compare commits

..

No commits in common. "4e7d0d806f2a692ff293442cae9a233c9109fd86" and "40c2b6dd501a8f4dd3a7fca94118e968c900e9d9" have entirely different histories.

4 changed files with 5 additions and 6 deletions

View File

@ -10,7 +10,7 @@ include = '\.pyi?$'
[project]
name = "aiia"
version = "0.1.3"
version = "0.1.2"
description = "AIIA Deep Learning Model Implementation"
readme = "README.md"
authors = [

View File

@ -1,6 +1,6 @@
[metadata]
name = aiia
version = 0.1.3
version = 0.1.2
author = Falko Habel
author_email = falko.habel@gmx.de
description = AIIA deep learning model implementation

View File

@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
from .pretrain.pretrainer import Pretrainer, ProjectionHead
__version__ = "0.1.3"
__version__ = "0.1.2"

View File

@ -6,7 +6,7 @@ from tqdm import tqdm
from ..model.Model import AIIA
from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader
import os
class ProjectionHead(nn.Module):
def __init__(self, hidden_size):
@ -189,8 +189,7 @@ class Pretrainer:
self.model.save(output_path)
print("Best model saved!")
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
self.save_losses(losses_path)
self.save_losses('losses.csv')
def _validate(self, val_loader, criterion_denoise, criterion_rotate):
"""Perform validation and return average validation loss."""