develop #4
|
@ -15,7 +15,7 @@ from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunke
|
||||||
|
|
||||||
class aiuNNDataset(torch.utils.data.Dataset):
|
class aiuNNDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, parquet_path):
|
def __init__(self, parquet_path):
|
||||||
self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2500)
|
self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2000)
|
||||||
|
|
||||||
self.augmentation = Compose([
|
self.augmentation = Compose([
|
||||||
RandomBrightnessContrast(p=0.5),
|
RandomBrightnessContrast(p=0.5),
|
||||||
|
@ -139,7 +139,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
|
||||||
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
|
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
|
||||||
if avg_val_loss < best_val_loss:
|
if avg_val_loss < best_val_loss:
|
||||||
best_val_loss = avg_val_loss
|
best_val_loss = avg_val_loss
|
||||||
torch.save(model.state_dict(), "best_model.pth")
|
model.save("best_model")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue