Merge pull request 'bug fixing realted to storing losses' (#16) from bug_fixes into main

Reviewed-on: #16
This commit is contained in:
Falko Victor Habel 2025-03-03 12:33:49 +00:00
commit 4e7d0d806f
4 changed files with 6 additions and 5 deletions

View File

@ -10,7 +10,7 @@ include = '\.pyi?$'
[project] [project]
name = "aiia" name = "aiia"
version = "0.1.2" version = "0.1.3"
description = "AIIA Deep Learning Model Implementation" description = "AIIA Deep Learning Model Implementation"
readme = "README.md" readme = "README.md"
authors = [ authors = [

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = aiia name = aiia
version = 0.1.2 version = 0.1.3
author = Falko Habel author = Falko Habel
author_email = falko.habel@gmx.de author_email = falko.habel@gmx.de
description = AIIA deep learning model implementation description = AIIA deep learning model implementation

View File

@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
from .pretrain.pretrainer import Pretrainer, ProjectionHead from .pretrain.pretrainer import Pretrainer, ProjectionHead
__version__ = "0.1.2" __version__ = "0.1.3"

View File

@ -6,7 +6,7 @@ from tqdm import tqdm
from ..model.Model import AIIA from ..model.Model import AIIA
from ..model.config import AIIAConfig from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader from ..data.DataLoader import AIIADataLoader
import os
class ProjectionHead(nn.Module): class ProjectionHead(nn.Module):
def __init__(self, hidden_size): def __init__(self, hidden_size):
@ -189,7 +189,8 @@ class Pretrainer:
self.model.save(output_path) self.model.save(output_path)
print("Best model saved!") print("Best model saved!")
self.save_losses('losses.csv') losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
self.save_losses(losses_path)
def _validate(self, val_loader, criterion_denoise, criterion_rotate): def _validate(self, val_loader, criterion_denoise, criterion_rotate):
"""Perform validation and return average validation loss.""" """Perform validation and return average validation loss."""