develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed files with 50 additions and 26 deletions
Showing only changes of commit be74658ceb - Show all commits

View File

@ -99,6 +99,27 @@ class ImageDataset(Dataset):
if 'high_res_stream' in locals(): if 'high_res_stream' in locals():
high_res_stream.close() high_res_stream.close()
class SuperResolutionModel(AIIA):
def __init__(self, base_model):
super(SuperResolutionModel, self).__init__()
# Use base model as encoder
self.encoder = base_model
for param in self.encoder.parameters():
param.requires_grad = False # Freeze encoder layers
# Add decoder layers to reconstruct image
self.decoder = nn.Sequential(
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(128, 3, kernel_size=3, padding=1)
)
def forward(self, x):
features = self.encoder(x)
output = self.decoder(features)
return output
class FineTuner: class FineTuner:
def __init__(self, def __init__(self,
@ -296,36 +317,39 @@ class FineTuner:
print(f"Train Loss: {epoch_loss:.4f}") print(f"Train Loss: {epoch_loss:.4f}")
return epoch_loss return epoch_loss
def _validate_epoch(self): def _train_epoch(self):
""" """Train model for one epoch"""
Validate model performance self.model.train()
Returns: running_loss = 0.0
float: Average validation loss for the epoch
"""
self.model.eval()
val_loss = 0.0
with torch.no_grad(): for batch in tqdm(self.train_loader, desc="Training"):
for batch in tqdm(self.val_loader, desc="Validation"):
low_ress = batch['low_ress'].to(self.device) low_ress = batch['low_ress'].to(self.device)
high_ress = batch['high_ress'].to(self.device) high_ress = batch['high_ress'].to(self.device)
# Forward pass
try: try:
features = self.model(low_ress) outputs = self.model(low_ress) # Now outputs are images
print("Output shape:", outputs.shape)
print("High-res shape:", high_ress.shape)
except Exception as e: except Exception as e:
raise RuntimeError(f"Error during validation forward pass: {str(e)}") raise RuntimeError(f"Error during forward pass: {str(e)}")
# Calculate same loss combination # Calculate loss
l1_loss = self.criterion(features, high_ress) * 0.5 l1_loss = self.criterion(outputs, high_ress) * 0.5
mse_loss = self.mse_criterion(features, high_ress) * 0.5 mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
total_loss = l1_loss + mse_loss total_loss = l1_loss + mse_loss
val_loss += total_loss.item() # Backward pass and optimize
self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()
self.current_val_loss = val_loss / len(self.val_loader) running_loss += total_loss.item()
self.val_losses.append(self.current_val_loss)
print(f"Validation Loss: {self.current_val_loss:.4f}") epoch_loss = running_loss / len(self.train_loader)
return self.current_val_loss self.train_lossess.append(epoch_loss)
print(f"Train Loss: {epoch_loss:.4f}")
return epoch_loss
def train(self, num_epochs: int = 10): def train(self, num_epochs: int = 10):
""" """
@ -369,7 +393,7 @@ class FineTuner:
if __name__ == "__main__": if __name__ == "__main__":
# Load your model first # Load your model first
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512") model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"))
trainer = FineTuner( trainer = FineTuner(
model=model, model=model,