develop #4
|
@ -15,7 +15,7 @@ from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunke
|
|||
|
||||
class aiuNNDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, parquet_path):
|
||||
self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024'])
|
||||
self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2500)
|
||||
|
||||
self.augmentation = Compose([
|
||||
RandomBrightnessContrast(p=0.5),
|
||||
|
@ -144,7 +144,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
|
|||
return model
|
||||
|
||||
def main():
|
||||
BATCH_SIZE = 1
|
||||
BATCH_SIZE = 2
|
||||
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
|
||||
|
||||
if hasattr(model, 'chunked_'):
|
||||
|
|
Loading…
Reference in New Issue