Merge pull request 'bug fixing realted to storing losses' (#16) from bug_fixes into main
Reviewed-on: #16
This commit is contained in:
commit
4e7d0d806f
|
@ -10,7 +10,7 @@ include = '\.pyi?$'
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "aiia"
|
name = "aiia"
|
||||||
version = "0.1.2"
|
version = "0.1.3"
|
||||||
description = "AIIA Deep Learning Model Implementation"
|
description = "AIIA Deep Learning Model Implementation"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[metadata]
|
[metadata]
|
||||||
name = aiia
|
name = aiia
|
||||||
version = 0.1.2
|
version = 0.1.3
|
||||||
author = Falko Habel
|
author = Falko Habel
|
||||||
author_email = falko.habel@gmx.de
|
author_email = falko.habel@gmx.de
|
||||||
description = AIIA deep learning model implementation
|
description = AIIA deep learning model implementation
|
||||||
|
|
|
@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
|
||||||
from .pretrain.pretrainer import Pretrainer, ProjectionHead
|
from .pretrain.pretrainer import Pretrainer, ProjectionHead
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.1.2"
|
__version__ = "0.1.3"
|
||||||
|
|
|
@ -6,7 +6,7 @@ from tqdm import tqdm
|
||||||
from ..model.Model import AIIA
|
from ..model.Model import AIIA
|
||||||
from ..model.config import AIIAConfig
|
from ..model.config import AIIAConfig
|
||||||
from ..data.DataLoader import AIIADataLoader
|
from ..data.DataLoader import AIIADataLoader
|
||||||
|
import os
|
||||||
|
|
||||||
class ProjectionHead(nn.Module):
|
class ProjectionHead(nn.Module):
|
||||||
def __init__(self, hidden_size):
|
def __init__(self, hidden_size):
|
||||||
|
@ -189,7 +189,8 @@ class Pretrainer:
|
||||||
self.model.save(output_path)
|
self.model.save(output_path)
|
||||||
print("Best model saved!")
|
print("Best model saved!")
|
||||||
|
|
||||||
self.save_losses('losses.csv')
|
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
|
||||||
|
self.save_losses(losses_path)
|
||||||
|
|
||||||
def _validate(self, val_loader, criterion_denoise, criterion_rotate):
|
def _validate(self, val_loader, criterion_denoise, criterion_rotate):
|
||||||
"""Perform validation and return average validation loss."""
|
"""Perform validation and return average validation loss."""
|
||||||
|
|
Loading…
Reference in New Issue