corrected viewing and some prints
This commit is contained in:
parent
b6b63851ca
commit
fe4d6b5b22
|
@ -6,7 +6,6 @@ from aiia.model import AIIABase
|
||||||
from aiia.data.DataLoader import AIIADataLoader
|
from aiia.data.DataLoader import AIIADataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def pretrain_model(data_path1, data_path2, num_epochs=3):
|
def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
# Read and merge datasets
|
# Read and merge datasets
|
||||||
df1 = pd.read_parquet(data_path1).head(10000)
|
df1 = pd.read_parquet(data_path1).head(10000)
|
||||||
|
@ -108,7 +107,19 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
noisy_imgs = noisy_imgs.to(device)
|
noisy_imgs = noisy_imgs.to(device)
|
||||||
targets = targets.to(device)
|
targets = targets.to(device)
|
||||||
|
|
||||||
|
# Print shapes for debugging
|
||||||
|
print(f"\nDenoising task shapes:")
|
||||||
|
print(f"Input shape: {noisy_imgs.shape}")
|
||||||
|
print(f"Target shape: {targets.shape}")
|
||||||
|
|
||||||
outputs = model(noisy_imgs)
|
outputs = model(noisy_imgs)
|
||||||
|
print(f"Raw output shape: {outputs.shape}")
|
||||||
|
|
||||||
|
# Reshape output to match target dimensions
|
||||||
|
batch_size = targets.size(0)
|
||||||
|
outputs = outputs.view(batch_size, 3, 224, 224)
|
||||||
|
print(f"Reshaped output shape: {outputs.shape}")
|
||||||
|
|
||||||
loss = criterion_denoise(outputs, targets)
|
loss = criterion_denoise(outputs, targets)
|
||||||
batch_loss += loss
|
batch_loss += loss
|
||||||
|
|
||||||
|
@ -118,7 +129,18 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
imgs = imgs.to(device)
|
imgs = imgs.to(device)
|
||||||
targets = targets.long().to(device)
|
targets = targets.long().to(device)
|
||||||
|
|
||||||
|
# Print shapes for debugging
|
||||||
|
print(f"\nRotation task shapes:")
|
||||||
|
print(f"Input shape: {imgs.shape}")
|
||||||
|
print(f"Target shape: {targets.shape}")
|
||||||
|
|
||||||
outputs = model(imgs)
|
outputs = model(imgs)
|
||||||
|
print(f"Raw output shape: {outputs.shape}")
|
||||||
|
|
||||||
|
# Reshape output for rotation classification
|
||||||
|
outputs = outputs.view(targets.size(0), -1) # Flatten to [batch_size, features]
|
||||||
|
print(f"Reshaped output shape: {outputs.shape}")
|
||||||
|
|
||||||
loss = criterion_rotate(outputs, targets)
|
loss = criterion_rotate(outputs, targets)
|
||||||
batch_loss += loss
|
batch_loss += loss
|
||||||
|
|
||||||
|
@ -150,6 +172,8 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
targets = targets.to(device)
|
targets = targets.to(device)
|
||||||
|
|
||||||
outputs = model(noisy_imgs)
|
outputs = model(noisy_imgs)
|
||||||
|
batch_size = targets.size(0)
|
||||||
|
outputs = outputs.view(batch_size, 3, 224, 224)
|
||||||
loss = criterion_denoise(outputs, targets)
|
loss = criterion_denoise(outputs, targets)
|
||||||
batch_loss += loss
|
batch_loss += loss
|
||||||
|
|
||||||
|
@ -160,6 +184,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
targets = targets.long().to(device)
|
targets = targets.long().to(device)
|
||||||
|
|
||||||
outputs = model(imgs)
|
outputs = model(imgs)
|
||||||
|
outputs = outputs.view(targets.size(0), -1)
|
||||||
loss = criterion_rotate(outputs, targets)
|
loss = criterion_rotate(outputs, targets)
|
||||||
batch_loss += loss
|
batch_loss += loss
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue