finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 2 additions and 2 deletions
Showing only changes of commit 19e5b72724 - Show all commits

View File

@ -15,7 +15,7 @@ from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunke
class aiuNNDataset(torch.utils.data.Dataset): class aiuNNDataset(torch.utils.data.Dataset):
def __init__(self, parquet_path): def __init__(self, parquet_path):
self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']) self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2500)
self.augmentation = Compose([ self.augmentation = Compose([
RandomBrightnessContrast(p=0.5), RandomBrightnessContrast(p=0.5),
@ -144,7 +144,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
return model return model
def main(): def main():
BATCH_SIZE = 1 BATCH_SIZE = 2
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512") model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
if hasattr(model, 'chunked_'): if hasattr(model, 'chunked_'):