develop #4

@@ -100,7 +100,6 @@ class ImageDataset(Dataset):
         high_res_stream.close()
-
 
 
 class FineTuner:
     def __init__(self,
                  model: AIIA,
@@ -129,7 +128,7 @@ class FineTuner:
         self.learning_rate = learning_rate
         self.train_ratio = train_ratio
         self.model = model
-        self.ouptut_dir = output_dir
+        self.output_dir = output_dir
 
         # Create output directory if it doesn't exist
         os.makedirs(output_dir, exist_ok=True)
@@ -153,12 +152,22 @@ class FineTuner:
         self._initialize_training()
 
     def _freeze_layers(self):
-        """Freeze all layers except the upsample layer"""
+        """
+        Freeze all layers except those that are part of the decoder or upsampling.
+        We'll assume the last few layers are responsible for upsampling/reconstruction.
+        """
         try:
-            # Try to freeze layers based on their names
+            # Try to identify encoder layers and freeze them
             for name, param in self.model.named_parameters():
-                if 'upsample' not in name:
+                if 'encoder' in name:
                     param.requires_grad = False
+
+            # Unfreeze certain layers (example: last decoder blocks)
+            # Modify this based on your actual model architecture
+            for name, param in self.model.named_parameters():
+                if 'decoder' in name and ('block4' in name or 'block5' in name):
+                    param.requires_grad = True
+
         except Exception as e:
             print(f"Warning: Couldn't freeze layers - {str(e)}")
             pass
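Aside on the freezing pattern introduced above: the new `_freeze_layers` first disables gradients for every parameter whose name contains 'encoder', then re-enables them only for the decoder blocks to be fine-tuned. A minimal, self-contained sketch of that idea follows; the helper name and the 'decoder.block4' / 'decoder.block5' keywords are illustrative, and the real AIIA parameter names may differ.

import torch.nn as nn

def freeze_for_finetuning(model: nn.Module,
                          trainable_keywords=("decoder.block4", "decoder.block5")) -> None:
    # Freeze every encoder parameter so only the reconstruction path is updated
    for name, param in model.named_parameters():
        if "encoder" in name:
            param.requires_grad = False
    # Re-enable gradients only for the decoder blocks listed in trainable_keywords
    for name, param in model.named_parameters():
        if any(key in name for key in trainable_keywords):
            param.requires_grad = True
    # Report how many parameters remain trainable (useful sanity check)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {trainable:,}")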
@@ -221,38 +230,9 @@ class FineTuner:
         """
         Helper method to initialize training parameters
         """
-        # Freeze all layers except upsample layer
+        # Freeze layers except those we want to finetune
         self._freeze_layers()
 
-        # Add upscaling layer if not already present
-        if not hasattr(self.model, 'upsample'):
-            # Try to get existing configuration or set defaults
-            try:
-                hidden_size = self.model.config.hidden_size
-                kernel_size = 3  # Use odd-sized kernel for better performance
-            except AttributeError:
-                # Fallback values if config isn't available
-                hidden_size = 512
-                kernel_size = 3
-
-            self.model.upsample = nn.Sequential(
-                nn.ConvTranspose2d(hidden_size,
-                                   hidden_size//2,
-                                   kernel_size=kernel_size,
-                                   stride=2,
-                                   padding=1,
-                                   output_padding=1),
-                nn.ReLU(inplace=True),
-                nn.ConvTranspose2d(hidden_size//2,
-                                   3,
-                                   kernel_size=kernel_size,
-                                   stride=2,
-                                   padding=1,
-                                   output_padding=1)
-            )
-            self.model.config.upsample_hidden_size = hidden_size
-            self.model.config.upsample_kernel_size = kernel_size
-
         # Initialize optimizer and scheduler
         params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]
 
@@ -291,18 +271,15 @@ class FineTuner:
                 low_ress = batch['low_ress'].to(self.device)
                 high_ress = batch['high_ress'].to(self.device)
 
+                # Forward pass (we'll use the model's existing architecture without adding layers)
                 try:
-                    # Try using CNN layer if available
-                    features = self.model.cnn(low_ress)
-                except AttributeError:
-                    # Fallback to extract_features method
-                    features = self.model.extract_features(low_ress)
-
-                outputs = self.model.upsample(features)
+                    features = self.model(low_ress)
+                except Exception as e:
+                    raise RuntimeError(f"Error during forward pass: {str(e)}")
 
                 # Calculate loss with different scaling for L1 and MSE components
-                l1_loss = self.criterion(outputs, high_ress) * 0.5
-                mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+                l1_loss = self.criterion(features, high_ress) * 0.5
+                mse_loss = self.mse_criterion(features, high_ress) * 0.5
                 total_loss = l1_loss + mse_loss
 
                 # Backward pass and optimize
@@ -332,15 +309,13 @@ class FineTuner:
                 high_ress = batch['high_ress'].to(self.device)
 
                 try:
-                    features = self.model.cnn(low_ress)
-                except AttributeError:
-                    features = self.model.extract_features(low_ress)
-
-                outputs = self.model.upsample(features)
+                    features = self.model(low_ress)
+                except Exception as e:
+                    raise RuntimeError(f"Error during validation forward pass: {str(e)}")
 
                 # Calculate same loss combination
-                l1_loss = self.criterion(outputs, high_ress) * 0.5
-                mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+                l1_loss = self.criterion(features, high_ress) * 0.5
+                mse_loss = self.mse_criterion(features, high_ress) * 0.5
                 total_loss = l1_loss + mse_loss
 
                 val_loss += total_loss.item()
@@ -380,12 +355,12 @@ class FineTuner:
             if self.current_val_loss < self.best_val_loss:
                 print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}")
                 self.best_val_loss = self.current_val_loss
-                model_save_path = os.path.join(self.ouptut_dir, "aiuNN-optimized")
+                model_save_path = os.path.join(self.output_dir, "aiuNN-optimized")
                 self.model.save(model_save_path)
                 print(f"Model saved to: {model_save_path}")
 
         # After training, save the final model
-        final_model_path = os.path.join(self.ouptut_dir, "aiuNN-final")
+        final_model_path = os.path.join(self.output_dir, "aiuNN-final")
         self.model.save(final_model_path)
         print(f"\nFinal model saved to: {final_model_path}")
 
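Note on the reworked forward pass: after this change, the trainer feeds the low-resolution batch straight through the model and compares the result against the high-resolution target with an equally weighted L1 + MSE loss; the separate cnn/extract_features/upsample path is gone. A minimal sketch of one training step under that assumption; the train_step name and the optimizer argument are illustrative, and it presumes model(low_ress) already returns a tensor shaped like the target.

import torch.nn as nn

def train_step(model: nn.Module, batch: dict, optimizer, device: str = "cuda") -> float:
    # Move the paired low-/high-resolution tensors to the training device
    low_ress = batch["low_ress"].to(device)
    high_ress = batch["high_ress"].to(device)

    # Single forward pass; no extra upsample head is attached to the model
    outputs = model(low_ress)

    # Equal 0.5/0.5 weighting of L1 and MSE, mirroring the diff above
    l1_loss = nn.L1Loss()(outputs, high_ress) * 0.5
    mse_loss = nn.MSELoss()(outputs, high_ress) * 0.5
    total_loss = l1_loss + mse_loss

    # Backward pass and optimizer update
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    return total_loss.item()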