added decoeder model
This commit is contained in:
parent
9187ebe012
commit
be74658ceb
|
@ -99,6 +99,27 @@ class ImageDataset(Dataset):
|
|||
if 'high_res_stream' in locals():
|
||||
high_res_stream.close()
|
||||
|
||||
class SuperResolutionModel(AIIA):
|
||||
def __init__(self, base_model):
|
||||
super(SuperResolutionModel, self).__init__()
|
||||
# Use base model as encoder
|
||||
self.encoder = base_model
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False # Freeze encoder layers
|
||||
|
||||
# Add decoder layers to reconstruct image
|
||||
self.decoder = nn.Sequential(
|
||||
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.ConvTranspose2d(128, 3, kernel_size=3, padding=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
features = self.encoder(x)
|
||||
output = self.decoder(features)
|
||||
return output
|
||||
|
||||
class FineTuner:
|
||||
def __init__(self,
|
||||
|
@ -296,36 +317,39 @@ class FineTuner:
|
|||
print(f"Train Loss: {epoch_loss:.4f}")
|
||||
return epoch_loss
|
||||
|
||||
def _validate_epoch(self):
|
||||
"""
|
||||
Validate model performance
|
||||
Returns:
|
||||
float: Average validation loss for the epoch
|
||||
"""
|
||||
self.model.eval()
|
||||
val_loss = 0.0
|
||||
def _train_epoch(self):
|
||||
"""Train model for one epoch"""
|
||||
self.model.train()
|
||||
running_loss = 0.0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(self.val_loader, desc="Validation"):
|
||||
low_ress = batch['low_ress'].to(self.device)
|
||||
high_ress = batch['high_ress'].to(self.device)
|
||||
for batch in tqdm(self.train_loader, desc="Training"):
|
||||
low_ress = batch['low_ress'].to(self.device)
|
||||
high_ress = batch['high_ress'].to(self.device)
|
||||
|
||||
try:
|
||||
features = self.model(low_ress)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error during validation forward pass: {str(e)}")
|
||||
# Forward pass
|
||||
try:
|
||||
outputs = self.model(low_ress) # Now outputs are images
|
||||
print("Output shape:", outputs.shape)
|
||||
print("High-res shape:", high_ress.shape)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error during forward pass: {str(e)}")
|
||||
|
||||
# Calculate same loss combination
|
||||
l1_loss = self.criterion(features, high_ress) * 0.5
|
||||
mse_loss = self.mse_criterion(features, high_ress) * 0.5
|
||||
total_loss = l1_loss + mse_loss
|
||||
# Calculate loss
|
||||
l1_loss = self.criterion(outputs, high_ress) * 0.5
|
||||
mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
|
||||
total_loss = l1_loss + mse_loss
|
||||
|
||||
val_loss += total_loss.item()
|
||||
# Backward pass and optimize
|
||||
self.optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
self.current_val_loss = val_loss / len(self.val_loader)
|
||||
self.val_losses.append(self.current_val_loss)
|
||||
print(f"Validation Loss: {self.current_val_loss:.4f}")
|
||||
return self.current_val_loss
|
||||
running_loss += total_loss.item()
|
||||
|
||||
epoch_loss = running_loss / len(self.train_loader)
|
||||
self.train_lossess.append(epoch_loss)
|
||||
print(f"Train Loss: {epoch_loss:.4f}")
|
||||
return epoch_loss
|
||||
|
||||
def train(self, num_epochs: int = 10):
|
||||
"""
|
||||
|
@ -369,7 +393,7 @@ class FineTuner:
|
|||
|
||||
if __name__ == "__main__":
|
||||
# Load your model first
|
||||
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
|
||||
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"))
|
||||
|
||||
trainer = FineTuner(
|
||||
model=model,
|
||||
|
|
Loading…
Reference in New Issue