handle both pretraining methods
parent 91a568731f
commit 13cf1897ae
@@ -6,7 +6,6 @@ from aiia.model import AIIABase
 from aiia.data.DataLoader import AIIADataLoader
 from tqdm import tqdm
 
-
 def pretrain_model(data_path1, data_path2, num_epochs=3):
     # Read and merge datasets
     df1 = pd.read_parquet(data_path1).head(10000)
@@ -21,7 +20,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
     # Initialize model and data loader
     model = AIIABase(config)
 
-    # Define a custom collate function to handle preprocessing and skip bad samples
     def safe_collate(batch):
         processed_batch = []
         for sample in batch:
@@ -37,11 +35,11 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
                 continue
 
         if not processed_batch:
-            return None  # Skip batch if all samples are invalid
+            return None
 
         # Stack tensors for the batch
         images = torch.stack([x['image'] for x in processed_batch])
-        targets = [x['target'] for x in processed_batch]  # Don't stack targets yet
+        targets = torch.stack([x['target'] for x in processed_batch])  # Stack targets
         tasks = [x['task'] for x in processed_batch]
 
         return (images, targets, tasks)
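
Since safe_collate now stacks targets with torch.stack, every target in a batch must share one shape; a batch mixing image-shaped denoising targets with scalar rotation labels would raise a stacking error. A minimal usage sketch, assuming safe_collate were exposed at module scope and a toy dataset yielding the {'image', 'target', 'task'} dicts the collate expects (both assumptions, not part of this commit):

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDenoiseDataset(Dataset):
    # Hypothetical stand-in for the real data; all samples are 'denoise'
    # tasks so torch.stack sees uniformly shaped targets.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        img = torch.rand(3, 224, 224)
        return {'image': img, 'target': img.clone(), 'task': 'denoise'}

# safe_collate is the nested function from the diff above, assumed here
# to be importable for illustration.
loader = DataLoader(ToyDenoiseDataset(), batch_size=4, collate_fn=safe_collate)
for batch in loader:
    if batch is None:  # all samples in the batch were invalid
        continue
    images, targets, tasks = batch  # images and targets: [4, 3, 224, 224]
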
@@ -57,7 +55,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
     train_loader = aiia_loader.train_loader
     val_loader = aiia_loader.val_loader
 
-    # Define loss functions and optimizer
     criterion_denoise = nn.MSELoss()
     criterion_rotate = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
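
As context for the loop changes below: nn.MSELoss compares two same-shaped tensors, while nn.CrossEntropyLoss expects [N, num_classes] logits against [N] integer class indices, which is why the new code lifts per-sample tensors with unsqueeze(0). A small self-contained sketch; the four rotation classes are an assumption:

import torch
import torch.nn as nn

criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()

# Denoising: per-sample MSE between a reconstruction and its clean target.
recon, clean = torch.rand(3, 224, 224), torch.rand(3, 224, 224)
denoise_loss = criterion_denoise(recon, clean)  # 0-dim loss tensor

# Rotation: one sample's flat logits become [1, num_classes] and the
# label becomes [1]; 4 rotation classes are assumed here.
logits = torch.rand(4)
label = torch.tensor(2)
rotate_loss = criterion_rotate(logits.unsqueeze(0), label.unsqueeze(0))
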
@@ -77,25 +74,29 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
 
         for batch in tqdm(train_loader):
             if batch is None:
-                continue  # Skip empty batches
+                continue
 
             noisy_imgs, targets, tasks = batch
+            batch_size = noisy_imgs.size(0)
             noisy_imgs = noisy_imgs.to(device)
+            targets = targets.to(device)
 
             optimizer.zero_grad()
 
+            # Get model outputs and reshape if necessary
             outputs = model(noisy_imgs)
-            task_losses = []
 
-            for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
+            task_losses = []
+            for i, task in enumerate(tasks):
                 if task == 'denoise':
-                    target = target.to(device)  # Move target to device
-                    # Ensure output and target have same shape
+                    # Ensure output matches target shape for denoising
+                    output = outputs[i].view(3, 224, 224)  # Reshape to match image dimensions
+                    target = targets[i]
                     loss = criterion_denoise(output, target)
                 else:  # rotate task
-                    target = target.to(device)  # Move target to device
-                    # For rotation task, output should be [batch_size, num_classes]
-                    loss = criterion_rotate(output.view(1, -1), target.view(-1))
+                    output = outputs[i].view(-1)  # Flatten output for rotation prediction
+                    target = targets[i].long()  # Convert target to long for classification
+                    loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
                 task_losses.append(loss)
 
             batch_loss = sum(task_losses) / len(task_losses)
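
The rewritten loop indexes the stacked outputs and targets by position and dispatches on the task string. A minimal sketch of that dispatch in isolation, keeping the hard-coded 3x224x224 shape from the diff; the helper name is ours, not the repository's:

import torch
import torch.nn as nn

criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()

def per_sample_losses(outputs, targets, tasks):
    # Mirrors the loop above: select each sample's output, reshape per
    # task, and collect one loss tensor per sample.
    losses = []
    for i, task in enumerate(tasks):
        if task == 'denoise':
            output = outputs[i].view(3, 224, 224)
            losses.append(criterion_denoise(output, targets[i]))
        else:  # 'rotate'
            output = outputs[i].view(-1)   # flat logits for one sample
            target = targets[i].long()     # class index must be integral
            losses.append(criterion_rotate(output.unsqueeze(0), target.unsqueeze(0)))
    return losses

# Two denoise samples with image-shaped outputs and targets:
outs = torch.rand(2, 3, 224, 224)
tgts = torch.rand(2, 3, 224, 224)
losses = per_sample_losses(outs, tgts, ['denoise', 'denoise'])
batch_loss = sum(losses) / len(losses)  # same averaging as the diff
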
@@ -118,17 +119,20 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
 
                 noisy_imgs, targets, tasks = batch
                 noisy_imgs = noisy_imgs.to(device)
+                targets = targets.to(device)
 
                 outputs = model(noisy_imgs)
-                task_losses = []
 
-                for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
+                task_losses = []
+                for i, task in enumerate(tasks):
                     if task == 'denoise':
-                        target = target.to(device)
+                        output = outputs[i].view(3, 224, 224)
+                        target = targets[i]
                         loss = criterion_denoise(output, target)
                     else:
-                        target = target.to(device)
-                        loss = criterion_rotate(output.view(1, -1), target.view(-1))
+                        output = outputs[i].view(-1)
+                        target = targets[i].long()
+                        loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
                     task_losses.append(loss)
 
                 batch_loss = sum(task_losses) / len(task_losses)
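
Finally, a hypothetical invocation of the updated entry point; the parquet paths are placeholders, not taken from this repository:

if __name__ == "__main__":
    # Placeholder dataset paths; substitute the real parquet files.
    pretrain_model("dataset_a.parquet", "dataset_b.parquet", num_epochs=3)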