updated the logging process #16

Merged
Fabel merged 1 commits from feat/checkpoints into main 2025-04-23 12:32:47 +00:00
1 changed files with 4 additions and 3 deletions

View File

@ -60,7 +60,7 @@ class aiuNNTrainer:
self.last_22_date = None self.last_22_date = None
self.recent_checkpoints = [] self.recent_checkpoints = []
self.current_epoch = 0 self.current_epoch = 0
def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None): def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
""" """
@ -203,7 +203,8 @@ class aiuNNTrainer:
"""Finetune the upscaler model""" """Finetune the upscaler model"""
if self.data_loader is None: if self.data_loader is None:
raise ValueError("Data not loaded. Call load_data first.") raise ValueError("Data not loaded. Call load_data first.")
# setup logging
self._setup_logging(output_path=output_path)
# Setup optimizer and directories # Setup optimizer and directories
self.optimizer = optim.Adam(self.model.parameters(), lr=lr) self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
self.checkpoint_dir = os.path.join(output_path, "checkpoints") self.checkpoint_dir = os.path.join(output_path, "checkpoints")
@ -215,7 +216,7 @@ class aiuNNTrainer:
writer = csv.writer(file) writer = csv.writer(file)
header = ['Epoch', 'Train Loss', 'Validation Loss'] if self.validation_loader else ['Epoch', 'Train Loss'] header = ['Epoch', 'Train Loss', 'Validation Loss'] if self.validation_loader else ['Epoch', 'Train Loss']
writer.writerow(header) writer.writerow(header)
# Load existing checkpoint if available # Load existing checkpoint if available
checkpoint_info = self.load_checkpoint() checkpoint_info = self.load_checkpoint()
start_epoch = checkpoint_info[0] if checkpoint_info else 0 start_epoch = checkpoint_info[0] if checkpoint_info else 0