From be74658ceb3043537f7947e4706c56472b3ff622 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Fri, 31 Jan 2025 17:30:41 +0100 Subject: [PATCH] added decoeder model --- src/aiunn/finetune.py | 76 ++++++++++++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 26 deletions(-) diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index 50acea1..6121e22 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -99,6 +99,27 @@ class ImageDataset(Dataset): if 'high_res_stream' in locals(): high_res_stream.close() +class SuperResolutionModel(AIIA): + def __init__(self, base_model): + super(SuperResolutionModel, self).__init__() + # Use base model as encoder + self.encoder = base_model + for param in self.encoder.parameters(): + param.requires_grad = False # Freeze encoder layers + + # Add decoder layers to reconstruct image + self.decoder = nn.Sequential( + nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1), + nn.ReLU(), + nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1), + nn.ReLU(), + nn.ConvTranspose2d(128, 3, kernel_size=3, padding=1) + ) + + def forward(self, x): + features = self.encoder(x) + output = self.decoder(features) + return output class FineTuner: def __init__(self, @@ -296,36 +317,39 @@ class FineTuner: print(f"Train Loss: {epoch_loss:.4f}") return epoch_loss - def _validate_epoch(self): - """ - Validate model performance - Returns: - float: Average validation loss for the epoch - """ - self.model.eval() - val_loss = 0.0 + def _train_epoch(self): + """Train model for one epoch""" + self.model.train() + running_loss = 0.0 - with torch.no_grad(): - for batch in tqdm(self.val_loader, desc="Validation"): - low_ress = batch['low_ress'].to(self.device) - high_ress = batch['high_ress'].to(self.device) + for batch in tqdm(self.train_loader, desc="Training"): + low_ress = batch['low_ress'].to(self.device) + high_ress = batch['high_ress'].to(self.device) - try: - features = self.model(low_ress) - except Exception as e: - raise RuntimeError(f"Error during validation forward pass: {str(e)}") + # Forward pass + try: + outputs = self.model(low_ress) # Now outputs are images + print("Output shape:", outputs.shape) + print("High-res shape:", high_ress.shape) + except Exception as e: + raise RuntimeError(f"Error during forward pass: {str(e)}") - # Calculate same loss combination - l1_loss = self.criterion(features, high_ress) * 0.5 - mse_loss = self.mse_criterion(features, high_ress) * 0.5 - total_loss = l1_loss + mse_loss + # Calculate loss + l1_loss = self.criterion(outputs, high_ress) * 0.5 + mse_loss = self.mse_criterion(outputs, high_ress) * 0.5 + total_loss = l1_loss + mse_loss - val_loss += total_loss.item() + # Backward pass and optimize + self.optimizer.zero_grad() + total_loss.backward() + self.optimizer.step() - self.current_val_loss = val_loss / len(self.val_loader) - self.val_losses.append(self.current_val_loss) - print(f"Validation Loss: {self.current_val_loss:.4f}") - return self.current_val_loss + running_loss += total_loss.item() + + epoch_loss = running_loss / len(self.train_loader) + self.train_lossess.append(epoch_loss) + print(f"Train Loss: {epoch_loss:.4f}") + return epoch_loss def train(self, num_epochs: int = 10): """ @@ -369,7 +393,7 @@ class FineTuner: if __name__ == "__main__": # Load your model first - model = AIIABase.load("/root/vision/AIIA/AIIA-base-512") + model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512")) trainer = FineTuner( model=model,