Merge pull request 'updated the logging process' (#16) from feat/checkpoints into main
Reviewed-on: #16
This commit is contained in:
commit
9b82024a40
|
@ -60,7 +60,7 @@ class aiuNNTrainer:
|
||||||
self.last_22_date = None
|
self.last_22_date = None
|
||||||
self.recent_checkpoints = []
|
self.recent_checkpoints = []
|
||||||
self.current_epoch = 0
|
self.current_epoch = 0
|
||||||
|
|
||||||
|
|
||||||
def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
|
def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
|
||||||
"""
|
"""
|
||||||
|
@ -203,7 +203,8 @@ class aiuNNTrainer:
|
||||||
"""Finetune the upscaler model"""
|
"""Finetune the upscaler model"""
|
||||||
if self.data_loader is None:
|
if self.data_loader is None:
|
||||||
raise ValueError("Data not loaded. Call load_data first.")
|
raise ValueError("Data not loaded. Call load_data first.")
|
||||||
|
# setup logging
|
||||||
|
self._setup_logging(output_path=output_path)
|
||||||
# Setup optimizer and directories
|
# Setup optimizer and directories
|
||||||
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
|
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
|
||||||
self.checkpoint_dir = os.path.join(output_path, "checkpoints")
|
self.checkpoint_dir = os.path.join(output_path, "checkpoints")
|
||||||
|
@ -215,7 +216,7 @@ class aiuNNTrainer:
|
||||||
writer = csv.writer(file)
|
writer = csv.writer(file)
|
||||||
header = ['Epoch', 'Train Loss', 'Validation Loss'] if self.validation_loader else ['Epoch', 'Train Loss']
|
header = ['Epoch', 'Train Loss', 'Validation Loss'] if self.validation_loader else ['Epoch', 'Train Loss']
|
||||||
writer.writerow(header)
|
writer.writerow(header)
|
||||||
|
|
||||||
# Load existing checkpoint if available
|
# Load existing checkpoint if available
|
||||||
checkpoint_info = self.load_checkpoint()
|
checkpoint_info = self.load_checkpoint()
|
||||||
start_epoch = checkpoint_info[0] if checkpoint_info else 0
|
start_epoch = checkpoint_info[0] if checkpoint_info else 0
|
||||||
|
|
Loading…
Reference in New Issue