added decoeder model

This commit is contained in:
Falko Victor Habel 2025-01-31 17:30:41 +01:00
parent 9187ebe012
commit be74658ceb
1 changed files with 50 additions and 26 deletions

View File

@ -99,6 +99,27 @@ class ImageDataset(Dataset):
if 'high_res_stream' in locals():
high_res_stream.close()
class SuperResolutionModel(AIIA):
def __init__(self, base_model):
super(SuperResolutionModel, self).__init__()
# Use base model as encoder
self.encoder = base_model
for param in self.encoder.parameters():
param.requires_grad = False # Freeze encoder layers
# Add decoder layers to reconstruct image
self.decoder = nn.Sequential(
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(128, 3, kernel_size=3, padding=1)
)
def forward(self, x):
features = self.encoder(x)
output = self.decoder(features)
return output
class FineTuner:
def __init__(self,
@ -296,36 +317,39 @@ class FineTuner:
print(f"Train Loss: {epoch_loss:.4f}")
return epoch_loss
def _validate_epoch(self):
"""
Validate model performance
Returns:
float: Average validation loss for the epoch
"""
self.model.eval()
val_loss = 0.0
def _train_epoch(self):
"""Train model for one epoch"""
self.model.train()
running_loss = 0.0
with torch.no_grad():
for batch in tqdm(self.val_loader, desc="Validation"):
low_ress = batch['low_ress'].to(self.device)
high_ress = batch['high_ress'].to(self.device)
for batch in tqdm(self.train_loader, desc="Training"):
low_ress = batch['low_ress'].to(self.device)
high_ress = batch['high_ress'].to(self.device)
try:
features = self.model(low_ress)
except Exception as e:
raise RuntimeError(f"Error during validation forward pass: {str(e)}")
# Forward pass
try:
outputs = self.model(low_ress) # Now outputs are images
print("Output shape:", outputs.shape)
print("High-res shape:", high_ress.shape)
except Exception as e:
raise RuntimeError(f"Error during forward pass: {str(e)}")
# Calculate same loss combination
l1_loss = self.criterion(features, high_ress) * 0.5
mse_loss = self.mse_criterion(features, high_ress) * 0.5
total_loss = l1_loss + mse_loss
# Calculate loss
l1_loss = self.criterion(outputs, high_ress) * 0.5
mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
total_loss = l1_loss + mse_loss
val_loss += total_loss.item()
# Backward pass and optimize
self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()
self.current_val_loss = val_loss / len(self.val_loader)
self.val_losses.append(self.current_val_loss)
print(f"Validation Loss: {self.current_val_loss:.4f}")
return self.current_val_loss
running_loss += total_loss.item()
epoch_loss = running_loss / len(self.train_loader)
self.train_lossess.append(epoch_loss)
print(f"Train Loss: {epoch_loss:.4f}")
return epoch_loss
def train(self, num_epochs: int = 10):
"""
@ -369,7 +393,7 @@ class FineTuner:
if __name__ == "__main__":
# Load your model first
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"))
trainer = FineTuner(
model=model,