Handle both pretraining methods (denoise and rotate)
This commit is contained in:
parent
91a568731f
commit
13cf1897ae
|
@ -6,7 +6,6 @@ from aiia.model import AIIABase
|
|||
from aiia.data.DataLoader import AIIADataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||
# Read and merge datasets
|
||||
df1 = pd.read_parquet(data_path1).head(10000)
|
||||
|
@ -21,7 +20,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
# Initialize model and data loader
|
||||
model = AIIABase(config)
|
||||
|
||||
# Define a custom collate function to handle preprocessing and skip bad samples
|
||||
def safe_collate(batch):
|
||||
processed_batch = []
|
||||
for sample in batch:
|
||||
|
@ -37,11 +35,11 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
continue
|
||||
|
||||
if not processed_batch:
|
||||
return None # Skip batch if all samples are invalid
|
||||
return None
|
||||
|
||||
# Stack tensors for the batch
|
||||
images = torch.stack([x['image'] for x in processed_batch])
|
||||
targets = [x['target'] for x in processed_batch] # Don't stack targets yet
|
||||
targets = torch.stack([x['target'] for x in processed_batch]) # Stack targets
|
||||
tasks = [x['task'] for x in processed_batch]
|
||||
|
||||
return (images, targets, tasks)
|
||||
|
@ -57,7 +55,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
train_loader = aiia_loader.train_loader
|
||||
val_loader = aiia_loader.val_loader
|
||||
|
||||
# Define loss functions and optimizer
|
||||
criterion_denoise = nn.MSELoss()
|
||||
criterion_rotate = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
|
||||
|
@ -77,25 +74,29 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
|
||||
for batch in tqdm(train_loader):
|
||||
if batch is None:
|
||||
continue # Skip empty batches
|
||||
continue
|
||||
|
||||
noisy_imgs, targets, tasks = batch
|
||||
batch_size = noisy_imgs.size(0)
|
||||
noisy_imgs = noisy_imgs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Get model outputs and reshape if necessary
|
||||
outputs = model(noisy_imgs)
|
||||
task_losses = []
|
||||
|
||||
for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
|
||||
task_losses = []
|
||||
for i, task in enumerate(tasks):
|
||||
if task == 'denoise':
|
||||
target = target.to(device) # Move target to device
|
||||
# Ensure output and target have same shape
|
||||
# Ensure output matches target shape for denoising
|
||||
output = outputs[i].view(3, 224, 224) # Reshape to match image dimensions
|
||||
target = targets[i]
|
||||
loss = criterion_denoise(output, target)
|
||||
else: # rotate task
|
||||
target = target.to(device) # Move target to device
|
||||
# For rotation task, output should be [batch_size, num_classes]
|
||||
loss = criterion_rotate(output.view(1, -1), target.view(-1))
|
||||
output = outputs[i].view(-1) # Flatten output for rotation prediction
|
||||
target = targets[i].long() # Convert target to long for classification
|
||||
loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
|
||||
task_losses.append(loss)
|
||||
|
||||
batch_loss = sum(task_losses) / len(task_losses)
|
||||
|
@ -118,17 +119,20 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
|
||||
noisy_imgs, targets, tasks = batch
|
||||
noisy_imgs = noisy_imgs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
outputs = model(noisy_imgs)
|
||||
task_losses = []
|
||||
|
||||
for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
|
||||
task_losses = []
|
||||
for i, task in enumerate(tasks):
|
||||
if task == 'denoise':
|
||||
target = target.to(device)
|
||||
output = outputs[i].view(3, 224, 224)
|
||||
target = targets[i]
|
||||
loss = criterion_denoise(output, target)
|
||||
else:
|
||||
target = target.to(device)
|
||||
loss = criterion_rotate(output.view(1, -1), target.view(-1))
|
||||
output = outputs[i].view(-1)
|
||||
target = targets[i].long()
|
||||
loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
|
||||
task_losses.append(loss)
|
||||
|
||||
batch_loss = sum(task_losses) / len(task_losses)
|
||||
|
|
Loading…
Reference in New Issue