Added tasks for both denoising and rotation
parent 13cf1897ae
commit b6b63851ca

src/pretrain.py: 145 changes
@@ -6,6 +6,7 @@ from aiia.model import AIIABase
 from aiia.data.DataLoader import AIIADataLoader
+from tqdm import tqdm
 
 
 def pretrain_model(data_path1, data_path2, num_epochs=3):
     # Read and merge datasets
     df1 = pd.read_parquet(data_path1).head(10000)
@@ -21,28 +22,49 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
     model = AIIABase(config)
 
     def safe_collate(batch):
-        processed_batch = []
+        denoise_batch = []
+        rotate_batch = []
+
         for sample in batch:
             try:
                 noisy_img, target, task = sample
-                processed_batch.append({
-                    'image': noisy_img,
-                    'target': target,
-                    'task': task
-                })
+                if task == 'denoise':
+                    denoise_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
+                else:  # rotate task
+                    rotate_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
             except Exception as e:
                 print(f"Skipping sample due to error: {e}")
                 continue
 
-        if not processed_batch:
+        if not denoise_batch and not rotate_batch:
             return None
 
-        # Stack tensors for the batch
-        images = torch.stack([x['image'] for x in processed_batch])
-        targets = torch.stack([x['target'] for x in processed_batch])  # Stack targets
-        tasks = [x['task'] for x in processed_batch]
-
-        return (images, targets, tasks)
+        batch_data = {
+            'denoise': None,
+            'rotate': None
+        }
+
+        # Process denoise batch
+        if denoise_batch:
+            images = torch.stack([x['image'] for x in denoise_batch])
+            targets = torch.stack([x['target'] for x in denoise_batch])
+            batch_data['denoise'] = (images, targets)
+
+        # Process rotate batch
+        if rotate_batch:
+            images = torch.stack([x['image'] for x in rotate_batch])
+            targets = torch.stack([x['target'] for x in rotate_batch])
+            batch_data['rotate'] = (images, targets)
+
+        return batch_data
 
     aiia_loader = AIIADataLoader(
         merged_df,
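For context, a minimal sketch of what the new collate function produces for a mixed batch. It assumes each dataset sample is a (noisy_img, target, task) triple, as the unpacking above implies, with 3x224x224 image tensors and an integer rotation class; since safe_collate is local to pretrain_model, this is illustrative rather than runnable as-is:

```python
import torch

# Hypothetical samples matching the (noisy_img, target, task) triples
# that safe_collate unpacks; shapes and the rotation class are assumptions.
samples = [
    (torch.randn(3, 224, 224), torch.randn(3, 224, 224), 'denoise'),
    (torch.randn(3, 224, 224), torch.tensor(2), 'rotate'),
]

batch_data = safe_collate(samples)

# Each task key holds either None or an (images, targets) tuple.
if batch_data['denoise'] is not None:
    images, targets = batch_data['denoise']
    print(images.shape)   # torch.Size([1, 3, 224, 224])
if batch_data['rotate'] is not None:
    images, targets = batch_data['rotate']
    print(targets)        # tensor([2])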
@@ -71,74 +93,81 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
         # Training phase
         model.train()
         total_train_loss = 0.0
+        batch_count = 0
 
-        for batch in tqdm(train_loader):
-            if batch is None:
+        for batch_data in tqdm(train_loader):
+            if batch_data is None:
                 continue
 
-            noisy_imgs, targets, tasks = batch
-            batch_size = noisy_imgs.size(0)
-            noisy_imgs = noisy_imgs.to(device)
-            targets = targets.to(device)
-
             optimizer.zero_grad()
             batch_loss = 0
 
-            # Get model outputs and reshape if necessary
-            outputs = model(noisy_imgs)
-
-            task_losses = []
-            for i, task in enumerate(tasks):
-                if task == 'denoise':
-                    # Ensure output matches target shape for denoising
-                    output = outputs[i].view(3, 224, 224)  # Reshape to match image dimensions
-                    target = targets[i]
-                    loss = criterion_denoise(output, target)
-                else:  # rotate task
-                    output = outputs[i].view(-1)  # Flatten output for rotation prediction
-                    target = targets[i].long()  # Convert target to long for classification
-                    loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
-                task_losses.append(loss)
-
-            batch_loss = sum(task_losses) / len(task_losses)
-            batch_loss.backward()
-            optimizer.step()
-
-            total_train_loss += batch_loss.item()
+            # Handle denoise task
+            if batch_data['denoise'] is not None:
+                noisy_imgs, targets = batch_data['denoise']
+                noisy_imgs = noisy_imgs.to(device)
+                targets = targets.to(device)
+
+                outputs = model(noisy_imgs)
+                loss = criterion_denoise(outputs, targets)
+                batch_loss += loss
+
+            # Handle rotate task
+            if batch_data['rotate'] is not None:
+                imgs, targets = batch_data['rotate']
+                imgs = imgs.to(device)
+                targets = targets.long().to(device)
+
+                outputs = model(imgs)
+                loss = criterion_rotate(outputs, targets)
+                batch_loss += loss
+
+            if batch_loss > 0:
+                batch_loss.backward()
+                optimizer.step()
+                total_train_loss += batch_loss.item()
+                batch_count += 1
 
-        avg_train_loss = total_train_loss / len(train_loader)
+        avg_train_loss = total_train_loss / max(batch_count, 1)
         print(f"Training Loss: {avg_train_loss:.4f}")
 
         # Validation phase
         model.eval()
         val_loss = 0.0
+        val_batch_count = 0
 
         with torch.no_grad():
-            for batch in val_loader:
-                if batch is None:
+            for batch_data in val_loader:
+                if batch_data is None:
                     continue
 
-                noisy_imgs, targets, tasks = batch
-                noisy_imgs = noisy_imgs.to(device)
-                targets = targets.to(device)
-
-                outputs = model(noisy_imgs)
-
-                task_losses = []
-                for i, task in enumerate(tasks):
-                    if task == 'denoise':
-                        output = outputs[i].view(3, 224, 224)
-                        target = targets[i]
-                        loss = criterion_denoise(output, target)
-                    else:
-                        output = outputs[i].view(-1)
-                        target = targets[i].long()
-                        loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
-                    task_losses.append(loss)
-
-                batch_loss = sum(task_losses) / len(task_losses)
-                val_loss += batch_loss.item()
+                batch_loss = 0
+
+                # Handle denoise task
+                if batch_data['denoise'] is not None:
+                    noisy_imgs, targets = batch_data['denoise']
+                    noisy_imgs = noisy_imgs.to(device)
+                    targets = targets.to(device)
+
+                    outputs = model(noisy_imgs)
+                    loss = criterion_denoise(outputs, targets)
+                    batch_loss += loss
+
+                # Handle rotate task
+                if batch_data['rotate'] is not None:
+                    imgs, targets = batch_data['rotate']
+                    imgs = imgs.to(device)
+                    targets = targets.long().to(device)
+
+                    outputs = model(imgs)
+                    loss = criterion_rotate(outputs, targets)
+                    batch_loss += loss
+
+                if batch_loss > 0:
+                    val_loss += batch_loss.item()
+                    val_batch_count += 1
 
-        avg_val_loss = val_loss / len(val_loader)
+        avg_val_loss = val_loss / max(val_batch_count, 1)
         print(f"Validation Loss: {avg_val_loss:.4f}")
 
         if avg_val_loss < best_val_loss:
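The hunks reference criterion_denoise and criterion_rotate, which are defined outside the changed lines. A plausible pairing, stated here as an assumption consistent with how the loop uses them (image-vs-image targets for denoising, targets cast to long for rotation classification):

```python
import torch.nn as nn

# Assumed criteria, not shown in this diff: a regression loss for image
# reconstruction and a classification loss over discrete rotation classes.
criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()
```

Accumulating both losses into batch_loss before a single backward() lets gradients from both pretext tasks flow into the shared backbone in one optimizer step, and the `if batch_loss > 0:` guard skips batches for which the collate function returned neither task.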