diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index a0e67b2..e616bd0 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -62,29 +62,27 @@ class aiuNNDataset(torch.utils.data.Dataset):
 
 class Upscaler(nn.Module):
     """
-    Wraps the base model to perform upsampling and a final convolution.
+    Transforms the base model's final feature map using a transposed convolution.
     The base model produces a feature map of size 512x512.
-    We then upsample by a factor of 2 (to get 1024x1024)
-    and use a convolution to map the hidden features to 3 output channels.
+    This layer upsamples by a factor of 2 (yielding 1024x1024) and maps the hidden features
+    to the output channels using a single ConvTranspose2d layer.
     """
     def __init__(self, base_model: AIIABase):
         super(Upscaler, self).__init__()
         self.base_model = base_model
-        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
-        self.final_conv = nn.Conv2d(
-            base_model.config.hidden_size,
-            base_model.config.num_channels,
-            kernel_size=3,
-            padding=1
+        # Instead of adding separate upsampling and convolutional layers, we use a ConvTranspose2d layer.
+        self.last_transform = nn.ConvTranspose2d(
+            in_channels=base_model.config.hidden_size,
+            out_channels=base_model.config.num_channels,
+            kernel_size=base_model.config.kernel_size,
+            stride=2,
+            padding=1,
+            output_padding=1
         )
 
     def forward(self, x):
-        # Get the feature maps from the base model (expected shape: [B, 512, 512, 512])
         features = self.base_model(x)
-        # Upsample the features to match high resolution (1024x1024)
-        upsampled = self.upsample(features)
-        # Convert from hidden features to output channels
-        return self.final_conv(upsampled)
+        return self.last_transform(features)
 
 def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
     # Load and concatenate datasets.
@@ -134,7 +132,6 @@ def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=1
             high_res = batch['high_res'].to(device)
             with autocast(device_type="cuda"):
                 if use_checkpoint:
-                    # Use checkpointing if requested.
                     low_res = batch['low_res'].to(device).requires_grad_()
                     features = checkpoint(lambda x: model(x), low_res)
                 else:
@@ -178,13 +175,13 @@ def main():
     ACCUMULATION_STEPS = 8
     USE_CHECKPOINT = False
 
-    # Load the base model using the config values (hidden_size=512, num_channels=3, etc.)
+    # Load the base model using the provided configuration (e.g., hidden_size=512, num_channels=3, etc.)
    base_model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
 
-    # Wrap the base model in our Upscaler so that the output is upsampled to 1024x1024
+    # Wrap the base model with our modified Upscaler that transforms its last layer.
    model = Upscaler(base_model)
 
-    print("Modified model architecture with upsampling wrapper:")
+    print("Modified model architecture with transformed final layer:")
    print(base_model.config)
 
    finetune_model(
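
Note on the hyperparameter choice: a minimal sketch (outside the patch) of why stride=2, padding=1, output_padding=1 doubles the spatial resolution. It assumes hidden_size=512, num_channels=3, and kernel_size=3, which are assumptions inferred from the surrounding comments and the removed Conv2d; the real AIIABase config may differ.

import torch
import torch.nn as nn

# Assumed config values (hypothetical; substitute base_model.config.* in the real code).
hidden_size, num_channels, kernel_size = 512, 3, 3

last_transform = nn.ConvTranspose2d(
    in_channels=hidden_size,
    out_channels=num_channels,
    kernel_size=kernel_size,
    stride=2,
    padding=1,
    output_padding=1,
)

# ConvTranspose2d output size: (in - 1) * stride - 2 * padding + kernel_size + output_padding.
# With stride=2, padding=1, output_padding=1 this is 2*in + kernel_size - 3, so the size
# doubles exactly only when kernel_size == 3 (e.g. 512 -> 1024).
features = torch.randn(1, hidden_size, 64, 64)  # small spatial size to keep the check cheap
out = last_transform(features)
print(out.shape)  # torch.Size([1, 3, 128, 128]); with a 512x512 input this would be 1024x1024

If base_model.config.kernel_size is not 3, the padding/output_padding pair would need adjusting to keep the exact 2x upscale.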