overall improvement
parent fe4d6b5b22
commit 58baf0ad3c
@@ -6,6 +6,18 @@ from aiia.model import AIIABase
 from aiia.data.DataLoader import AIIADataLoader
 from tqdm import tqdm
 
+
+class ProjectionHead(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
+        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees
+
+    def forward(self, x, task='denoise'):
+        if task == 'denoise':
+            return self.conv_denoise(x)
+        else:
+            return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task
+
 def pretrain_model(data_path1, data_path2, num_epochs=3):
     # Read and merge datasets
     df1 = pd.read_parquet(data_path1).head(10000)
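For context, the ProjectionHead added above maps the 512-channel feature maps coming out of AIIABase either back to a 3-channel image (denoising) or to 4 rotation logits via global average pooling. A minimal shape check, assuming the base model emits feature maps of shape [batch, 512, H, W] (implied by the Conv2d(512, ...) layers, not shown in this diff):

import torch
import torch.nn as nn

class ProjectionHead(nn.Module):  # same definition as in the diff above
    def __init__(self):
        super().__init__()
        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        return self.conv_rotate(x).mean(dim=(2, 3))

head = ProjectionHead()
features = torch.randn(2, 512, 224, 224)       # assumed feature-map shape; spatial size may differ
print(head(features, task='denoise').shape)    # torch.Size([2, 3, 224, 224])
print(head(features, task='rotate').shape)     # torch.Size([2, 4])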
@@ -17,9 +29,14 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
         model_name="AIIA-Base-512x20k",
     )
 
-    # Initialize model and data loader
+    # Initialize model and projection head
     model = AIIABase(config)
+    projection_head = ProjectionHead()
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.to(device)
+    projection_head.to(device)
 
     def safe_collate(batch):
         denoise_batch = []
         rotate_batch = []
@@ -51,13 +68,11 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
             'rotate': None
         }
 
-        # Process denoise batch
         if denoise_batch:
             images = torch.stack([x['image'] for x in denoise_batch])
             targets = torch.stack([x['target'] for x in denoise_batch])
             batch_data['denoise'] = (images, targets)
 
-        # Process rotate batch
         if rotate_batch:
             images = torch.stack([x['image'] for x in rotate_batch])
             targets = torch.stack([x['target'] for x in rotate_batch])
@@ -78,10 +93,12 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
 
     criterion_denoise = nn.MSELoss()
     criterion_rotate = nn.CrossEntropyLoss()
-    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model.to(device)
+    # Update optimizer to include projection head parameters
+    optimizer = torch.optim.AdamW(
+        list(model.parameters()) + list(projection_head.parameters()),
+        lr=config.learning_rate
+    )
 
     best_val_loss = float('inf')
 
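Because the AdamW optimizer now receives the concatenated parameter lists, the base model and the projection head are updated in the same step. A quick, hypothetical sanity check (not part of this commit) that the head actually contributes trainable parameters:

head_params = sum(p.numel() for p in projection_head.parameters() if p.requires_grad)
print(f"Projection head trainable parameters: {head_params}")  # two 1x1 convs plus their biases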
@@ -91,6 +108,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
 
         # Training phase
         model.train()
+        projection_head.train()
         total_train_loss = 0.0
         batch_count = 0
 
@@ -107,18 +125,16 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
                 noisy_imgs = noisy_imgs.to(device)
                 targets = targets.to(device)
 
-                # Print shapes for debugging
+                # Get features from base model
+                features = model(noisy_imgs)
+                # Project features back to image space
+                outputs = projection_head(features, task='denoise')
+
                 print(f"\nDenoising task shapes:")
                 print(f"Input shape: {noisy_imgs.shape}")
                 print(f"Target shape: {targets.shape}")
-                outputs = model(noisy_imgs)
-                print(f"Raw output shape: {outputs.shape}")
-
-                # Reshape output to match target dimensions
-                batch_size = targets.size(0)
-                outputs = outputs.view(batch_size, 3, 224, 224)
-                print(f"Reshaped output shape: {outputs.shape}")
-
+                print(f"Features shape: {features.shape}")
+                print(f"Output shape: {outputs.shape}")
 
                 loss = criterion_denoise(outputs, targets)
                 batch_loss += loss
@@ -129,17 +145,16 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
                 imgs = imgs.to(device)
                 targets = targets.long().to(device)
 
-                # Print shapes for debugging
+                # Get features from base model
+                features = model(imgs)
+                # Project features to rotation predictions
+                outputs = projection_head(features, task='rotate')
+
                 print(f"\nRotation task shapes:")
                 print(f"Input shape: {imgs.shape}")
                 print(f"Target shape: {targets.shape}")
-                outputs = model(imgs)
-                print(f"Raw output shape: {outputs.shape}")
-
-                # Reshape output for rotation classification
-                outputs = outputs.view(targets.size(0), -1)  # Flatten to [batch_size, features]
-                print(f"Reshaped output shape: {outputs.shape}")
-
+                print(f"Features shape: {features.shape}")
+                print(f"Output shape: {outputs.shape}")
 
                 loss = criterion_rotate(outputs, targets)
                 batch_loss += loss
@@ -155,6 +170,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
 
         # Validation phase
         model.eval()
+        projection_head.eval()
        val_loss = 0.0
        val_batch_count = 0
 
@@ -165,26 +181,23 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
 
                 batch_loss = 0
 
-                # Handle denoise task
                 if batch_data['denoise'] is not None:
                     noisy_imgs, targets = batch_data['denoise']
                     noisy_imgs = noisy_imgs.to(device)
                     targets = targets.to(device)
 
-                    outputs = model(noisy_imgs)
-                    batch_size = targets.size(0)
-                    outputs = outputs.view(batch_size, 3, 224, 224)
+                    features = model(noisy_imgs)
+                    outputs = projection_head(features, task='denoise')
                     loss = criterion_denoise(outputs, targets)
                     batch_loss += loss
 
-                # Handle rotate task
                 if batch_data['rotate'] is not None:
                     imgs, targets = batch_data['rotate']
                     imgs = imgs.to(device)
                     targets = targets.long().to(device)
 
-                    outputs = model(imgs)
-                    outputs = outputs.view(targets.size(0), -1)
+                    features = model(imgs)
+                    outputs = projection_head(features, task='rotate')
                     loss = criterion_rotate(outputs, targets)
                     batch_loss += loss
@@ -197,7 +210,8 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
 
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
-            model.save("BASEv0.1")
+            # Save both model and projection head
+            model.save("AIIA-base-512")
             print("Best model saved!")
 
 if __name__ == "__main__":
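The comment in the last hunk mentions saving both the model and the projection head, but only model.save("AIIA-base-512") appears in the diff. If the head's weights are needed to resume pretraining, they could be written alongside the base model with a plain state_dict; a minimal sketch, where the directory and file name are illustrative assumptions rather than paths used by this commit:

import os
import torch

save_dir = "AIIA-base-512"                                    # assumed output location of model.save
os.makedirs(save_dir, exist_ok=True)
torch.save(projection_head.state_dict(),
           os.path.join(save_dir, "projection_head.pth"))     # illustrative file name

# Restoring later:
# projection_head.load_state_dict(torch.load(os.path.join(save_dir, "projection_head.pth")))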