finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed file with 39 additions and 17 deletions
Showing only changes of commit 75be3291d3


@@ -60,7 +60,33 @@ class aiuNNDataset(torch.utils.data.Dataset):
             'high_res': augmented_high['image']
         }
 
-def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
+class Upscaler(nn.Module):
+    """
+    Wraps the base model to perform upsampling and a final convolution.
+    The base model produces a 512-channel feature map at 512x512 resolution.
+    We then upsample by a factor of 2 (to get 1024x1024)
+    and use a convolution to map the hidden features to 3 output channels.
+    """
+    def __init__(self, base_model: AIIABase):
+        super(Upscaler, self).__init__()
+        self.base_model = base_model
+        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+        self.final_conv = nn.Conv2d(
+            base_model.config.hidden_size,
+            base_model.config.num_channels,
+            kernel_size=3,
+            padding=1
+        )
+
+    def forward(self, x):
+        # Get the feature maps from the base model (expected shape: [B, 512, 512, 512])
+        features = self.base_model(x)
+        # Upsample the features to match high resolution (1024x1024)
+        upsampled = self.upsample(features)
+        # Convert from hidden features to output channels
+        return self.final_conv(upsampled)
+
+def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
     # Load and concatenate datasets.
     loaded_datasets = [aiuNNDataset(d) for d in datasets]
     combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
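
A quick way to sanity-check the new wrapper's shape flow (not part of the commit): StubBase and StubConfig below stand in for AIIABase and its config, which this diff does not show, and the 64x64 input is scaled down from the real 512x512 pipeline just to keep the check cheap.

import torch
import torch.nn as nn

class StubConfig:
    hidden_size = 512
    num_channels = 3

class StubBase(nn.Module):
    # Stand-in for AIIABase: maps num_channels -> hidden_size, preserving H x W.
    def __init__(self):
        super().__init__()
        self.config = StubConfig()
        self.conv = nn.Conv2d(self.config.num_channels, self.config.hidden_size,
                              kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = Upscaler(StubBase())   # Upscaler as defined in this commit
x = torch.randn(1, 3, 64, 64)  # stands in for a 512x512 low-res image
print(model(x).shape)          # torch.Size([1, 3, 128, 128]): spatial dims doubled, 3 channels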
@@ -93,7 +119,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
 
     model = model.to(device)
     criterion = nn.MSELoss()
-    optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
+    optimizer = torch.optim.Adam(model.parameters(), lr=model.base_model.config.learning_rate)
     scaler = GradScaler()
     best_val_loss = float('inf')
 
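
The learning-rate lookup moves to model.base_model.config because the Upscaler wrapper itself carries no config attribute. An alternative (hypothetical, not in this commit) is to re-export the config on the wrapper so callers can keep writing model.config:

# Inside Upscaler.__init__, after self.base_model = base_model:
#     self.config = base_model.config
# The optimizer line could then stay lr=model.config.learning_rate.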
@@ -108,10 +134,11 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
             high_res = batch['high_res'].to(device)
             with autocast(device_type="cuda"):
                 if use_checkpoint:
-                    outputs = checkpoint(lambda x: model(x), low_res)
+                    # Use checkpointing if requested.
+                    features = checkpoint(lambda x: model(x), low_res)
                 else:
-                    outputs = model(low_res)
-                loss = criterion(outputs, high_res) / accumulation_steps
+                    features = model(low_res)
+                loss = criterion(features, high_res) / accumulation_steps
             scaler.scale(loss).backward()
             train_loss += loss.item() * accumulation_steps
             if i % accumulation_steps == 0:
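
For reference, the accumulation arithmetic above divides each batch loss by accumulation_steps so the summed gradients equal the gradient of the mean loss over a window, then steps the optimizer once per window. Below is a self-contained toy sketch of the same pattern (toy model, random data, all names illustrative); it steps on (i + 1) % accumulation_steps == 0, a common variant that defers the first optimizer step until a full window has accumulated, whereas the loop above fires when i % accumulation_steps == 0.

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
accumulation_steps = 8

optimizer.zero_grad()
for i in range(32):
    x, y = torch.randn(2, 4), torch.randn(2, 1)
    # Scale down so accumulated gradients match the mean loss over the window.
    loss = criterion(model(x), y) / accumulation_steps
    loss.backward()                        # gradients add up across iterations
    if (i + 1) % accumulation_steps == 0:  # one optimizer step per full window
        optimizer.step()
        optimizer.zero_grad()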
@@ -142,7 +169,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
         print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
-            model.save("best_model")
+            model.base_model.save("best_model")
     return model
 
 def main():
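
Note that model.base_model.save("best_model") persists only the base model; the wrapper's newly added final_conv weights are not included in that checkpoint. A sketch of saving the full wrapper with plain torch.save instead (the .pt filename is illustrative):

import torch

# Save everything, including final_conv (filename is hypothetical).
torch.save(model.state_dict(), "best_model_upscaler.pt")

# Restore later: rebuild the wrapper around a freshly loaded base, then load weights.
# restored = Upscaler(AIIABase.load("/root/vision/AIIA/AIIA-base-512"))
# restored.load_state_dict(torch.load("best_model_upscaler.pt"))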
@@ -150,19 +177,14 @@ def main():
     ACCUMULATION_STEPS = 8
     USE_CHECKPOINT = True
 
-    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
+    # Load the base model using the config values (hidden_size=512, num_channels=3, etc.)
+    base_model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
 
-    # Upsample output if a 'chunked_' attribute exists, ensuring spatial dimensions match the high resolution images.
-    if hasattr(model, 'chunked_'):
-        model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
-
-    # Append a final convolutional layer using values from model.config.
-    # This converts the hidden feature maps (512 channels) to the 3 required output channels.
-    final_conv = nn.Conv2d(model.config.hidden_size, model.config.num_channels, kernel_size=3, padding=1)
-    model.add_module('final_layer', final_conv)
+    # Wrap the base model in our Upscaler so that the output is upsampled to 1024x1024.
+    model = Upscaler(base_model)
 
-    print("Modified model architecture:")
-    print(model.config)
+    print("Modified model architecture with upsampling wrapper:")
+    print(base_model.config)
 
     finetune_model(
         model=model,