updated script

Falko Victor Habel 2025-02-15 13:27:30 +01:00
parent 619e17c32c
commit 4c84086932
1 changed file with 15 additions and 18 deletions


@@ -62,29 +62,27 @@ class aiuNNDataset(torch.utils.data.Dataset):
 class Upscaler(nn.Module):
     """
-    Wraps the base model to perform upsampling and a final convolution.
+    Transforms the base model's final feature map using a transposed convolution.
     The base model produces a feature map of size 512x512.
-    We then upsample by a factor of 2 (to get 1024x1024)
-    and use a convolution to map the hidden features to 3 output channels.
+    This layer upsamples by a factor of 2 (yielding 1024x1024) and maps the hidden features
+    to the output channels using a single ConvTranspose2d layer.
     """
     def __init__(self, base_model: AIIABase):
         super(Upscaler, self).__init__()
         self.base_model = base_model
-        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
-        self.final_conv = nn.Conv2d(
-            base_model.config.hidden_size,
-            base_model.config.num_channels,
-            kernel_size=3,
-            padding=1
+        # Instead of adding separate upsampling and convolutional layers, we use a ConvTranspose2d layer.
+        self.last_transform = nn.ConvTranspose2d(
+            in_channels=base_model.config.hidden_size,
+            out_channels=base_model.config.num_channels,
+            kernel_size=base_model.config.kernel_size,
+            stride=2,
+            padding=1,
+            output_padding=1
         )
 
     def forward(self, x):
-        # Get the feature maps from the base model (expected shape: [B, 512, 512, 512])
         features = self.base_model(x)
-        # Upsample the features to match high resolution (1024x1024)
-        upsampled = self.upsample(features)
-        # Convert from hidden features to output channels
-        return self.final_conv(upsampled)
+        return self.last_transform(features)
 
 
 def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
     # Load and concatenate datasets.
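For reference, the output height of nn.ConvTranspose2d is (H_in - 1)*stride - 2*padding + kernel_size + output_padding (with dilation=1), so the stride=2, padding=1, output_padding=1 combination doubles the spatial size exactly only when kernel_size=3. A minimal shape check, using hypothetical stand-ins for the config values (hidden_size=512, num_channels=3, kernel_size=3) and a small spatial size to keep it cheap:

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for base_model.config: hidden_size=512, num_channels=3, kernel_size=3.
    layer = nn.ConvTranspose2d(in_channels=512, out_channels=3, kernel_size=3,
                               stride=2, padding=1, output_padding=1)

    # Small spatial size for a quick check; the same arithmetic turns the real 512x512 map into 1024x1024.
    features = torch.randn(1, 512, 64, 64)
    print(layer(features).shape)  # torch.Size([1, 3, 128, 128])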
@@ -134,7 +132,6 @@ def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
             high_res = batch['high_res'].to(device)
             with autocast(device_type="cuda"):
                 if use_checkpoint:
-                    # Use checkpointing if requested.
                     low_res = batch['low_res'].to(device).requires_grad_()
                     features = checkpoint(lambda x: model(x), low_res)
                 else:
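Note on the checkpoint branch kept above: torch.utils.checkpoint discards intermediate activations and recomputes the wrapped forward pass during backward, and it only propagates gradients if at least one tensor input requires grad, which is why low_res is marked with requires_grad_(). A minimal sketch of the pattern with a toy module standing in for the real model (assumes a CUDA device, matching the autocast call above):

    import torch
    import torch.nn as nn
    from torch.amp import autocast
    from torch.utils.checkpoint import checkpoint

    model = nn.Conv2d(3, 3, kernel_size=3, padding=1).cuda()             # toy stand-in for the real model
    low_res = torch.randn(1, 3, 64, 64, device="cuda").requires_grad_()  # input must require grad

    with autocast(device_type="cuda"):
        # Activations inside this call are recomputed during backward instead of stored.
        features = checkpoint(lambda x: model(x), low_res)

    features.float().sum().backward()  # gradients flow back through the recomputed segment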
@@ -178,13 +175,13 @@ def main():
     ACCUMULATION_STEPS = 8
     USE_CHECKPOINT = False
 
-    # Load the base model using the config values (hidden_size=512, num_channels=3, etc.)
+    # Load the base model using the provided configuration (e.g., hidden_size=512, num_channels=3, etc.)
     base_model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
 
-    # Wrap the base model in our Upscaler so that the output is upsampled to 1024x1024
+    # Wrap the base model with our modified Upscaler that transforms its last layer.
     model = Upscaler(base_model)
 
-    print("Modified model architecture with upsampling wrapper:")
+    print("Modified model architecture with transformed final layer:")
     print(base_model.config)
 
     finetune_model(
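The finetune_model call is truncated in this view; per its signature above it runs a mixed-precision loop with gradient accumulation (accumulation_steps=8). A generic sketch of that pattern, not the project's actual loop, with toy tensors standing in for the aiuNNDataset batches (assumes a CUDA device and a recent PyTorch with torch.amp):

    import torch
    import torch.nn as nn
    from torch.amp import autocast, GradScaler

    device = "cuda"
    model = nn.Conv2d(3, 3, kernel_size=3, padding=1).to(device)  # toy stand-in for Upscaler(base_model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    criterion = nn.MSELoss()
    scaler = GradScaler()
    accumulation_steps = 8

    # Toy batches shaped like the dataset's {'low_res', 'high_res'} samples (same size here for simplicity).
    batches = [{'low_res': torch.randn(1, 3, 64, 64), 'high_res': torch.randn(1, 3, 64, 64)}
               for _ in range(16)]

    for step, batch in enumerate(batches):
        with autocast(device_type=device):
            output = model(batch['low_res'].to(device))
            loss = criterion(output, batch['high_res'].to(device)) / accumulation_steps
        scaler.scale(loss).backward()               # accumulate scaled gradients
        if (step + 1) % accumulation_steps == 0:
            scaler.step(optimizer)                  # apply the accumulated update
            scaler.update()
            optimizer.zero_grad()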