From 75be3291d3d849ca3eb8be046012e505887247cb Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Fri, 14 Feb 2025 22:39:20 +0100
Subject: [PATCH] added custom upscaler model

---
 src/aiunn/finetune.py | 56 ++++++++++++++++++++++++++++++-------------
 1 file changed, 39 insertions(+), 17 deletions(-)

diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 774da0c..6d24bb3 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -60,7 +60,33 @@ class aiuNNDataset(torch.utils.data.Dataset):
             'high_res': augmented_high['image']
         }
 
-def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
+class Upscaler(nn.Module):
+    """
+    Wraps the base model to perform upsampling and a final convolution.
+    The base model produces a feature map of size 512x512.
+    We then upsample by a factor of 2 (to get 1024x1024)
+    and use a convolution to map the hidden features to 3 output channels.
+    """
+    def __init__(self, base_model: AIIABase):
+        super(Upscaler, self).__init__()
+        self.base_model = base_model
+        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+        self.final_conv = nn.Conv2d(
+            base_model.config.hidden_size,
+            base_model.config.num_channels,
+            kernel_size=3,
+            padding=1
+        )
+
+    def forward(self, x):
+        # Get the feature maps from the base model (expected shape: [B, 512, 512, 512])
+        features = self.base_model(x)
+        # Upsample the features to match high resolution (1024x1024)
+        upsampled = self.upsample(features)
+        # Convert from hidden features to output channels
+        return self.final_conv(upsampled)
+
+def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
     # Load and concatenate datasets.
     loaded_datasets = [aiuNNDataset(d) for d in datasets]
     combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
@@ -93,7 +119,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
 
     model = model.to(device)
     criterion = nn.MSELoss()
-    optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
+    optimizer = torch.optim.Adam(model.parameters(), lr=model.base_model.config.learning_rate)
     scaler = GradScaler()
 
     best_val_loss = float('inf')
@@ -108,10 +134,11 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
             high_res = batch['high_res'].to(device)
             with autocast(device_type="cuda"):
                 if use_checkpoint:
-                    outputs = checkpoint(lambda x: model(x), low_res)
+                    # Use checkpointing if requested.
+                    features = checkpoint(lambda x: model(x), low_res)
                 else:
-                    outputs = model(low_res)
-                loss = criterion(outputs, high_res) / accumulation_steps
+                    features = model(low_res)
+                loss = criterion(features, high_res) / accumulation_steps
             scaler.scale(loss).backward()
             train_loss += loss.item() * accumulation_steps
             if i % accumulation_steps == 0:
@@ -142,7 +169,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
         print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
-            model.save("best_model")
+            model.base_model.save("best_model")
     return model
 
 def main():
@@ -150,19 +177,14 @@ def main():
     ACCUMULATION_STEPS = 8
     USE_CHECKPOINT = True
 
-    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
+    # Load the base model using the config values (hidden_size=512, num_channels=3, etc.)
+    base_model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
 
-    # Upsample output if a 'chunked_' attribute exists, ensuring spatial dimensions match the high resolution images.
-    if hasattr(model, 'chunked_'):
-        model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
+    # Wrap the base model in our Upscaler so that the output is upsampled to 1024x1024
+    model = Upscaler(base_model)
 
-    # Append a final convolutional layer using values from model.config.
-    # This converts the hidden feature maps (512 channels) to the 3 required output channels.
-    final_conv = nn.Conv2d(model.config.hidden_size, model.config.num_channels, kernel_size=3, padding=1)
-    model.add_module('final_layer', final_conv)
-
-    print("Modified model architecture:")
-    print(model.config)
+    print("Modified model architecture with upsampling wrapper:")
+    print(base_model.config)
 
     finetune_model(
         model=model,
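A quick way to sanity-check the new Upscaler wiring is to push a dummy tensor through a stand-in base model. The sketch below is not part of the patch: DummyConfig and DummyBase are hypothetical stand-ins for AIIABase (all the wrapper relies on is a config exposing hidden_size and num_channels, plus a forward pass that preserves spatial size), and the Upscaler here mirrors the class added above.

    import torch
    import torch.nn as nn

    class DummyConfig:
        hidden_size = 512
        num_channels = 3

    class DummyBase(nn.Module):
        # Hypothetical stand-in for AIIABase: 3-channel input -> hidden_size
        # feature map, spatial resolution preserved.
        def __init__(self):
            super().__init__()
            self.config = DummyConfig()
            self.conv = nn.Conv2d(3, self.config.hidden_size, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x)

    class Upscaler(nn.Module):
        # Same structure as the patch: upsample x2, then map hidden features
        # to the output channels.
        def __init__(self, base_model):
            super().__init__()
            self.base_model = base_model
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            self.final_conv = nn.Conv2d(base_model.config.hidden_size,
                                        base_model.config.num_channels,
                                        kernel_size=3, padding=1)

        def forward(self, x):
            return self.final_conv(self.upsample(self.base_model(x)))

    model = Upscaler(DummyBase())
    x = torch.randn(1, 3, 64, 64)        # small stand-in for a 512x512 low-res image
    y = model(x)
    assert y.shape == (1, 3, 128, 128)   # 2x spatial, hidden_size -> num_channels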
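One caveat on the checkpointing branch: torch.utils.checkpoint.checkpoint defaults to the reentrant implementation, which produces no gradients for the checkpointed segment when none of its tensor inputs requires grad, and low_res comes straight from the dataloader, so it normally doesn't. Recent PyTorch releases also warn that use_reentrant should be passed explicitly. Assuming PyTorch 1.11 or newer, a safer variant of that line would be:

    from torch.utils.checkpoint import checkpoint

    # nn.Module instances are callable, so the lambda is unnecessary;
    # use_reentrant=False selects the non-reentrant implementation, which
    # backpropagates into parameters even when inputs don't require grad.
    features = checkpoint(model, low_res, use_reentrant=False)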