develop #4
|
@ -99,6 +99,27 @@ class ImageDataset(Dataset):
|
||||||
if 'high_res_stream' in locals():
|
if 'high_res_stream' in locals():
|
||||||
high_res_stream.close()
|
high_res_stream.close()
|
||||||
|
|
||||||
|
class SuperResolutionModel(AIIA):
|
||||||
|
def __init__(self, base_model):
|
||||||
|
super(SuperResolutionModel, self).__init__()
|
||||||
|
# Use base model as encoder
|
||||||
|
self.encoder = base_model
|
||||||
|
for param in self.encoder.parameters():
|
||||||
|
param.requires_grad = False # Freeze encoder layers
|
||||||
|
|
||||||
|
# Add decoder layers to reconstruct image
|
||||||
|
self.decoder = nn.Sequential(
|
||||||
|
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.ConvTranspose2d(128, 3, kernel_size=3, padding=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
features = self.encoder(x)
|
||||||
|
output = self.decoder(features)
|
||||||
|
return output
|
||||||
|
|
||||||
class FineTuner:
|
class FineTuner:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -296,36 +317,39 @@ class FineTuner:
|
||||||
print(f"Train Loss: {epoch_loss:.4f}")
|
print(f"Train Loss: {epoch_loss:.4f}")
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
|
||||||
def _validate_epoch(self):
|
def _train_epoch(self):
|
||||||
"""
|
"""Train model for one epoch"""
|
||||||
Validate model performance
|
self.model.train()
|
||||||
Returns:
|
running_loss = 0.0
|
||||||
float: Average validation loss for the epoch
|
|
||||||
"""
|
|
||||||
self.model.eval()
|
|
||||||
val_loss = 0.0
|
|
||||||
|
|
||||||
with torch.no_grad():
|
for batch in tqdm(self.train_loader, desc="Training"):
|
||||||
for batch in tqdm(self.val_loader, desc="Validation"):
|
|
||||||
low_ress = batch['low_ress'].to(self.device)
|
low_ress = batch['low_ress'].to(self.device)
|
||||||
high_ress = batch['high_ress'].to(self.device)
|
high_ress = batch['high_ress'].to(self.device)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
try:
|
try:
|
||||||
features = self.model(low_ress)
|
outputs = self.model(low_ress) # Now outputs are images
|
||||||
|
print("Output shape:", outputs.shape)
|
||||||
|
print("High-res shape:", high_ress.shape)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Error during validation forward pass: {str(e)}")
|
raise RuntimeError(f"Error during forward pass: {str(e)}")
|
||||||
|
|
||||||
# Calculate same loss combination
|
# Calculate loss
|
||||||
l1_loss = self.criterion(features, high_ress) * 0.5
|
l1_loss = self.criterion(outputs, high_ress) * 0.5
|
||||||
mse_loss = self.mse_criterion(features, high_ress) * 0.5
|
mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
|
||||||
total_loss = l1_loss + mse_loss
|
total_loss = l1_loss + mse_loss
|
||||||
|
|
||||||
val_loss += total_loss.item()
|
# Backward pass and optimize
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
total_loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
self.current_val_loss = val_loss / len(self.val_loader)
|
running_loss += total_loss.item()
|
||||||
self.val_losses.append(self.current_val_loss)
|
|
||||||
print(f"Validation Loss: {self.current_val_loss:.4f}")
|
epoch_loss = running_loss / len(self.train_loader)
|
||||||
return self.current_val_loss
|
self.train_lossess.append(epoch_loss)
|
||||||
|
print(f"Train Loss: {epoch_loss:.4f}")
|
||||||
|
return epoch_loss
|
||||||
|
|
||||||
def train(self, num_epochs: int = 10):
|
def train(self, num_epochs: int = 10):
|
||||||
"""
|
"""
|
||||||
|
@ -369,7 +393,7 @@ class FineTuner:
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Load your model first
|
# Load your model first
|
||||||
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
|
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"))
|
||||||
|
|
||||||
trainer = FineTuner(
|
trainer = FineTuner(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
Loading…
Reference in New Issue