From 7a1eb8bd3098360035b835d83a5a8a7dfe091bca Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sun, 26 Jan 2025 21:58:13 +0100 Subject: [PATCH] updated loader --- src/aiia/data/DataLoader.py | 102 ++++++++++++------------------------ src/pretrain.py | 14 +---- 2 files changed, 36 insertions(+), 80 deletions(-) diff --git a/src/aiia/data/DataLoader.py b/src/aiia/data/DataLoader.py index 98b882b..223d146 100644 --- a/src/aiia/data/DataLoader.py +++ b/src/aiia/data/DataLoader.py @@ -58,23 +58,22 @@ class JPGImageLoader: def _get_image(self, item): try: - # Retrieve the string data data = item[self.bytes_column] - # Check if the data is a string, and decode it - if isinstance(data, str): - bytes_data = base64.b64decode(data) # Adjust decoding as per your data encoding format + if isinstance(data, str) and data.startswith("b'"): + cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1') + bytes_data = cleaned_data + elif isinstance(data, str): + bytes_data = base64.b64decode(data) else: bytes_data = data - # Load the bytes into a BytesIO object and open the image img_bytes = io.BytesIO(bytes_data) image = Image.open(img_bytes).convert("RGB") return image except Exception as e: print(f"Error loading image from bytes: {e}") return None - def get_item(self, idx): item = self.dataset.iloc[idx] @@ -93,94 +92,61 @@ class JPGImageLoader: def print_summary(self): print(f"Successfully converted {self.successful_count} images.") print(f"Skipped {self.skipped_count} images due to errors.") - class AIIADataLoader: - def __init__(self, dataset, - batch_size=32, - val_split=0.2, - seed=42, - column="file_path", - label_column=None, - **dataloader_kwargs): + def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, **dataloader_kwargs): self.batch_size = batch_size self.val_split = val_split self.seed = seed random.seed(seed) - is_bytes_or_bytestring = any( - isinstance(value, (bytes, memoryview)) - for value in dataset[column].dropna().head(1).astype(str) + sample_value = dataset[column].iloc[0] + is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and ( + isinstance(sample_value, bytes) or + sample_value.startswith("b'") or + sample_value.startswith(('b"', 'data:image')) ) if is_bytes_or_bytestring: - self.loader = JPGImageLoader( - dataset, - bytes_column=column, - label_column=label_column - ) + self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column) else: sample_paths = dataset[column].dropna().head(1).astype(str) + filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$' - filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|png|gif)$' - - if any( - re.match(filepath_pattern, path, flags=re.IGNORECASE) - for path in sample_paths - ): - self.loader = FilePathLoader( - dataset, - file_path_column=column, - label_column=label_column - ) + if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths): + self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column) else: - self.loader = JPGImageLoader( - dataset, - bytes_column=column, - label_column=label_column - ) + self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column) - self.items = [self.loader.get_item(idx) for idx in range(len(dataset))] + self.items = [] + for idx in range(len(dataset)): + item = self.loader.get_item(idx) + if item is not None: + self.items.append(item) + if not self.items: + raise ValueError("No valid items were loaded from the dataset") + train_indices, val_indices = self._split_data() self.train_dataset = self._create_subset(train_indices) self.val_dataset = self._create_subset(val_indices) - self.train_loader = DataLoader( - self.train_dataset, - batch_size=batch_size, - shuffle=True, - **dataloader_kwargs - ) - - self.val_loader = DataLoader( - self.val_dataset, - batch_size=batch_size, - shuffle=False, - **dataloader_kwargs - ) + self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs) + self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs) def _split_data(self): if len(self.items) == 0: - return [], [] + raise ValueError("No items to split") - tasks = [item[1] for item in self.items if len(item) > 1 and hasattr(item, '__getitem__') and item[1] is not None] - - unique_tasks = list(set(tasks)) if tasks.count(None) < len(tasks) else [] + num_samples = len(self.items) + indices = list(range(num_samples)) + random.shuffle(indices) - train_indices = [] - val_indices = [] + split_idx = int((1 - self.val_split) * num_samples) + train_indices = indices[:split_idx] + val_indices = indices[split_idx:] - for task in unique_tasks: - task_indices = [i for i, t in enumerate(tasks) if t == task] - n_val = int(len(task_indices) * self.val_split) - - random.shuffle(task_indices) - - val_indices.extend(task_indices[:n_val]) - train_indices.extend(task_indices[n_val:]) - return train_indices, val_indices def _create_subset(self, indices): @@ -218,4 +184,4 @@ class AIIADataset(torch.utils.data.Dataset): if isinstance(item, Image.Image): return item else: - raise ValueError("Invalid item format.") + raise ValueError("Invalid item format.") \ No newline at end of file diff --git a/src/pretrain.py b/src/pretrain.py index 09e9856..78ce63a 100644 --- a/src/pretrain.py +++ b/src/pretrain.py @@ -18,24 +18,14 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): # Create a new AIIAConfig instance config = AIIAConfig( - model_name="AIIA-512x", - hidden_size=512, - num_hidden_layers=12, - kernel_size=5, - learning_rate=5e-5 + model_name="AIIA-Base-512x20k", ) # Initialize the base model model = AIIABase(config) # Create dataset loader with merged data - aiia_loader = AIIADataLoader( - merged_df, - batch_size=32, - val_split=0.2, - seed=42, - column="image_bytes" - ) + aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32) # Access the train and validation loaders train_loader = aiia_loader.train_loader