updated the logging process #16
|
@ -60,7 +60,7 @@ class aiuNNTrainer:
|
|||
self.last_22_date = None
|
||||
self.recent_checkpoints = []
|
||||
self.current_epoch = 0
|
||||
|
||||
|
||||
|
||||
def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
|
||||
"""
|
||||
|
@ -203,7 +203,8 @@ class aiuNNTrainer:
|
|||
"""Finetune the upscaler model"""
|
||||
if self.data_loader is None:
|
||||
raise ValueError("Data not loaded. Call load_data first.")
|
||||
|
||||
# setup logging
|
||||
self._setup_logging(output_path=output_path)
|
||||
# Setup optimizer and directories
|
||||
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
|
||||
self.checkpoint_dir = os.path.join(output_path, "checkpoints")
|
||||
|
@ -215,7 +216,7 @@ class aiuNNTrainer:
|
|||
writer = csv.writer(file)
|
||||
header = ['Epoch', 'Train Loss', 'Validation Loss'] if self.validation_loader else ['Epoch', 'Train Loss']
|
||||
writer.writerow(header)
|
||||
|
||||
|
||||
# Load existing checkpoint if available
|
||||
checkpoint_info = self.load_checkpoint()
|
||||
start_epoch = checkpoint_info[0] if checkpoint_info else 0
|
||||
|
|
Loading…
Reference in New Issue