finetune_class #1
@@ -60,7 +60,33 @@ class aiuNNDataset(torch.utils.data.Dataset):
             'high_res': augmented_high['image']
         }
 
-def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
+class Upscaler(nn.Module):
+    """
+    Wraps the base model to perform upsampling and a final convolution.
+    The base model produces a feature map of size 512x512.
+    We then upsample by a factor of 2 (to get 1024x1024)
+    and use a convolution to map the hidden features to 3 output channels.
+    """
+    def __init__(self, base_model: AIIABase):
+        super(Upscaler, self).__init__()
+        self.base_model = base_model
+        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+        self.final_conv = nn.Conv2d(
+            base_model.config.hidden_size,
+            base_model.config.num_channels,
+            kernel_size=3,
+            padding=1
+        )
+
+    def forward(self, x):
+        # Get the feature maps from the base model (expected shape: [B, 512, 512, 512])
+        features = self.base_model(x)
+        # Upsample the features to match high resolution (1024x1024)
+        upsampled = self.upsample(features)
+        # Convert from hidden features to output channels
+        return self.final_conv(upsampled)
+
+def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=10, accumulation_steps=8, use_checkpoint=False):
     # Load and concatenate datasets.
     loaded_datasets = [aiuNNDataset(d) for d in datasets]
     combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
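A quick way to sanity-check the wrapper's shape contract without the AIIA weights: the sketch below swaps in a hypothetical stub for `AIIABase` (its real architecture isn't part of this diff) and verifies that the upsample-then-conv pipeline doubles the spatial size and emits 3 channels. `StubConfig` and `StubBase` are stand-ins, not names from the codebase.

```python
import torch
import torch.nn as nn

class StubConfig:
    hidden_size = 512   # channel count of the base model's feature map
    num_channels = 3    # RGB output

class StubBase(nn.Module):
    """Hypothetical stand-in for AIIABase: maps a 3-channel image to a
    hidden_size-channel feature map at the same spatial resolution."""
    def __init__(self):
        super().__init__()
        self.config = StubConfig()
        self.proj = nn.Conv2d(3, self.config.hidden_size, kernel_size=3, padding=1)

    def forward(self, x):
        return self.proj(x)

# Same three stages as the PR's Upscaler: base features -> 2x upsample -> RGB conv.
base = StubBase()
upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
final_conv = nn.Conv2d(base.config.hidden_size, base.config.num_channels,
                       kernel_size=3, padding=1)

# 64x64 here instead of 512x512 just to keep the smoke test cheap; the
# contract is the same: [B, 3, H, W] in, [B, 3, 2H, 2W] out.
x = torch.randn(1, 3, 64, 64)
out = final_conv(upsample(base(x)))
print(out.shape)  # torch.Size([1, 3, 128, 128])
```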
@@ -93,7 +119,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
 
     model = model.to(device)
     criterion = nn.MSELoss()
-    optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
+    optimizer = torch.optim.Adam(model.parameters(), lr=model.base_model.config.learning_rate)
     scaler = GradScaler()
     best_val_loss = float('inf')
 
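Worth noting: since `model` is now the `Upscaler` wrapper, `model.parameters()` covers both the pretrained base and the freshly initialized `final_conv` (the `upsample` layer has no parameters). If the pretrained weights should adapt more slowly than the new head, per-group learning rates are the standard Adam pattern; a sketch against this PR's classes, where the 0.1 factor is an illustrative choice and not something the diff prescribes:

```python
# Hypothetical alternative to the single-lr optimizer above: discriminative
# learning rates for the pretrained base vs. the newly added head.
base_lr = model.base_model.config.learning_rate
optimizer = torch.optim.Adam([
    {'params': model.base_model.parameters(), 'lr': base_lr * 0.1},  # pretrained: gentle updates
    {'params': model.final_conv.parameters(), 'lr': base_lr},        # new head: full rate
])
```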
@@ -108,10 +134,11 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
             high_res = batch['high_res'].to(device)
             with autocast(device_type="cuda"):
                 if use_checkpoint:
-                    outputs = checkpoint(lambda x: model(x), low_res)
+                    # Use checkpointing if requested.
+                    features = checkpoint(lambda x: model(x), low_res)
                 else:
-                    outputs = model(low_res)
-                loss = criterion(outputs, high_res) / accumulation_steps
+                    features = model(low_res)
+                loss = criterion(features, high_res) / accumulation_steps
             scaler.scale(loss).backward()
             train_loss += loss.item() * accumulation_steps
             if i % accumulation_steps == 0:
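For context on the accumulation arithmetic: the loss is divided by `accumulation_steps` so the summed gradients match one large-batch step, and `loss.item() * accumulation_steps` undoes that division for logging. The hunk cuts off before the optimizer step; a minimal self-contained sketch of the usual pattern, using `(i + 1) %` so the step doesn't fire on batch 0 in case `enumerate` starts at zero here:

```python
from torch.amp import autocast, GradScaler

def train_epoch(model, loader, criterion, optimizer, scaler, device,
                accumulation_steps=8):
    """Minimal sketch of the AMP + gradient-accumulation pattern used above."""
    model.train()
    optimizer.zero_grad()
    running_loss = 0.0
    for i, batch in enumerate(loader):
        low_res = batch['low_res'].to(device)
        high_res = batch['high_res'].to(device)
        with autocast(device_type="cuda"):
            output = model(low_res)
            # Scale down so the accumulated gradient matches one large-batch step.
            loss = criterion(output, high_res) / accumulation_steps
        scaler.scale(loss).backward()
        running_loss += loss.item() * accumulation_steps  # undo division for logging
        if (i + 1) % accumulation_steps == 0:
            scaler.step(optimizer)   # unscales grads, skips the step on overflow
            scaler.update()
            optimizer.zero_grad()
    # Gradients from a trailing partial window are discarded here; add a final
    # flush step if the loader length isn't divisible by accumulation_steps.
    return running_loss / max(len(loader), 1)

# usage: scaler = GradScaler(); train_epoch(model, loader, criterion, opt, scaler, "cuda")
```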
@@ -142,7 +169,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
         print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
-            model.save("best_model")
+            model.base_model.save("best_model")
     return model
 
 def main():
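One thing to flag on the save path: `model.base_model.save("best_model")` persists only the wrapped AIIA weights, while the trained `final_conv` lives on the `Upscaler` itself. If the whole pipeline needs to be reloadable, a state-dict checkpoint alongside the base-model save would cover it; a sketch assuming this PR's classes, with an illustrative file name:

```python
import torch

# Save: the base model via its own API, plus the full wrapper's state dict
# so the trained final_conv weights are not lost.
model.base_model.save("best_model")
torch.save(model.state_dict(), "best_model_upscaler.pt")  # hypothetical path

# Load: rebuild the wrapper around a freshly loaded base, then restore
# every parameter, including final_conv.
restored = Upscaler(AIIABase.load("best_model"))
restored.load_state_dict(torch.load("best_model_upscaler.pt"))
```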
@@ -150,19 +177,14 @@ def main():
     ACCUMULATION_STEPS = 8
     USE_CHECKPOINT = True
 
-    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
-
-    # Upsample output if a 'chunked_' attribute exists, ensuring spatial dimensions match the high resolution images.
-    if hasattr(model, 'chunked_'):
-        model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
-
-    # Append a final convolutional layer using values from model.config.
-    # This converts the hidden feature maps (512 channels) to the 3 required output channels.
-    final_conv = nn.Conv2d(model.config.hidden_size, model.config.num_channels, kernel_size=3, padding=1)
-    model.add_module('final_layer', final_conv)
-
-    print("Modified model architecture:")
-    print(model.config)
+    # Load the base model using the config values (hidden_size=512, num_channels=3, etc.)
+    base_model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
+
+    # Wrap the base model in our Upscaler so that the output is upsampled to 1024x1024.
+    model = Upscaler(base_model)
+
+    print("Modified model architecture with upsampling wrapper:")
+    print(base_model.config)
 
     finetune_model(
         model=model,