updated loader

Falko Victor Habel 2025-01-26 21:58:13 +01:00
parent a8cd9b00e5
commit 7a1eb8bd30
2 changed files with 36 additions and 80 deletions

View File

@@ -58,23 +58,22 @@ class JPGImageLoader:
     def _get_image(self, item):
         try:
-            # Retrieve the string data
             data = item[self.bytes_column]
-            # Check if the data is a string, and decode it
-            if isinstance(data, str):
-                bytes_data = base64.b64decode(data)  # Adjust decoding as per your data encoding format
+            if isinstance(data, str) and data.startswith("b'"):
+                cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
+                bytes_data = cleaned_data
+            elif isinstance(data, str):
+                bytes_data = base64.b64decode(data)
             else:
                 bytes_data = data
-            # Load the bytes into a BytesIO object and open the image
             img_bytes = io.BytesIO(bytes_data)
             image = Image.open(img_bytes).convert("RGB")
             return image
         except Exception as e:
             print(f"Error loading image from bytes: {e}")
             return None

     def get_item(self, idx):
         item = self.dataset.iloc[idx]
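
Note: the new first branch decodes byte payloads that were stringified (for example by str() during a CSV or JSON round trip) rather than base64-encoded. A minimal standalone sketch of that round trip, assuming the column holds the repr of raw JPEG bytes:

    original = b'\xff\xd8\xff\xe0'  # JPEG magic bytes
    stringified = str(original)     # "b'\\xff\\xd8\\xff\\xe0'" -- what lands in the column

    # Strip the b'...' wrapper, then undo Python's escape syntax: latin1 maps each
    # character to exactly one byte, and unicode-escape resolves the \xNN sequences.
    cleaned = stringified[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
    assert cleaned == original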
@@ -93,94 +92,61 @@ class AIIADataLoader:
     def print_summary(self):
         print(f"Successfully converted {self.successful_count} images.")
         print(f"Skipped {self.skipped_count} images due to errors.")

 class AIIADataLoader:
-    def __init__(self, dataset,
-                 batch_size=32,
-                 val_split=0.2,
-                 seed=42,
-                 column="file_path",
-                 label_column=None,
-                 **dataloader_kwargs):
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, **dataloader_kwargs):
         self.batch_size = batch_size
         self.val_split = val_split
         self.seed = seed
         random.seed(seed)

-        is_bytes_or_bytestring = any(
-            isinstance(value, (bytes, memoryview))
-            for value in dataset[column].dropna().head(1).astype(str)
-        )
+        sample_value = dataset[column].iloc[0]
+        is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
+            isinstance(sample_value, bytes) or
+            sample_value.startswith("b'") or
+            sample_value.startswith(('b"', 'data:image'))
+        )

         if is_bytes_or_bytestring:
-            self.loader = JPGImageLoader(
-                dataset,
-                bytes_column=column,
-                label_column=label_column
-            )
+            self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
         else:
             sample_paths = dataset[column].dropna().head(1).astype(str)
-            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|png|gif)$'
+            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'

-            if any(
-                re.match(filepath_pattern, path, flags=re.IGNORECASE)
-                for path in sample_paths
-            ):
-                self.loader = FilePathLoader(
-                    dataset,
-                    file_path_column=column,
-                    label_column=label_column
-                )
+            if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
+                self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
             else:
-                self.loader = JPGImageLoader(
-                    dataset,
-                    bytes_column=column,
-                    label_column=label_column
-                )
+                self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)

-        self.items = [self.loader.get_item(idx) for idx in range(len(dataset))]
+        self.items = []
+        for idx in range(len(dataset)):
+            item = self.loader.get_item(idx)
+            if item is not None:
+                self.items.append(item)
+
+        if not self.items:
+            raise ValueError("No valid items were loaded from the dataset")

         train_indices, val_indices = self._split_data()
         self.train_dataset = self._create_subset(train_indices)
         self.val_dataset = self._create_subset(val_indices)

-        self.train_loader = DataLoader(
-            self.train_dataset,
-            batch_size=batch_size,
-            shuffle=True,
-            **dataloader_kwargs
-        )
-        self.val_loader = DataLoader(
-            self.val_dataset,
-            batch_size=batch_size,
-            shuffle=False,
-            **dataloader_kwargs
-        )
+        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
+        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)

     def _split_data(self):
         if len(self.items) == 0:
-            return [], []
+            raise ValueError("No items to split")

-        tasks = [item[1] for item in self.items if len(item) > 1 and hasattr(item, '__getitem__') and item[1] is not None]
-
-        unique_tasks = list(set(tasks)) if tasks.count(None) < len(tasks) else []
-
-        train_indices = []
-        val_indices = []
-
-        for task in unique_tasks:
-            task_indices = [i for i, t in enumerate(tasks) if t == task]
-            n_val = int(len(task_indices) * self.val_split)
-            random.shuffle(task_indices)
-            val_indices.extend(task_indices[:n_val])
-            train_indices.extend(task_indices[n_val:])
+        num_samples = len(self.items)
+        indices = list(range(num_samples))
+        random.shuffle(indices)
+        split_idx = int((1 - self.val_split) * num_samples)
+        train_indices = indices[:split_idx]
+        val_indices = indices[split_idx:]

         return train_indices, val_indices

     def _create_subset(self, indices):
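
Note: _split_data drops the per-task stratified split in favor of a plain shuffled split. A self-contained sketch of the new logic, using the class defaults (val_split=0.2, seed=42):

    import random

    random.seed(42)
    num_samples = 10
    indices = list(range(num_samples))
    random.shuffle(indices)
    split_idx = int((1 - 0.2) * num_samples)  # 80% train
    train_indices, val_indices = indices[:split_idx], indices[split_idx:]
    print(len(train_indices), len(val_indices))  # 8 2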
@@ -218,4 +184,4 @@ class AIIADataset(torch.utils.data.Dataset):
         if isinstance(item, Image.Image):
             return item
         else:
             raise ValueError("Invalid item format.")
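
Note: after this change the loader is picked by sniffing the first value of `column`: bytes or b'-prefixed strings route to JPGImageLoader, path-like strings to FilePathLoader. A hedged usage sketch (the DataFrame contents and file names are hypothetical; the constructor arguments are the ones from the diff above):

    import pandas as pd

    # Stringified bytes start with "b'", so they route to JPGImageLoader;
    # values like "images/a.jpg" would match the file-path regex instead.
    df = pd.DataFrame({
        "image_bytes": [str(open(p, "rb").read()) for p in ("a.jpg", "b.jpg")]
    })

    loader = AIIADataLoader(df, column="image_bytes", batch_size=2)
    for batch in loader.train_loader:
        pass  # PIL images may need a custom collate_fn passed via **dataloader_kwargs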

View File

@@ -18,24 +18,14 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
     # Create a new AIIAConfig instance
     config = AIIAConfig(
-        model_name="AIIA-512x",
-        hidden_size=512,
-        num_hidden_layers=12,
-        kernel_size=5,
-        learning_rate=5e-5
+        model_name="AIIA-Base-512x20k",
     )

     # Initialize the base model
     model = AIIABase(config)

     # Create dataset loader with merged data
-    aiia_loader = AIIADataLoader(
-        merged_df,
-        batch_size=32,
-        val_split=0.2,
-        seed=42,
-        column="image_bytes"
-    )
+    aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32)

     # Access the train and validation loaders
     train_loader = aiia_loader.train_loader
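
Note: the trimmed AIIAConfig call presumably falls back to the config's own defaults for hidden_size, num_hidden_layers, kernel_size, and learning_rate, and the loader call likewise keeps the default val_split and seed. A hedged sketch of how the loaders might be consumed further down in pretrain_model (the loop body is an assumption, not part of this diff):

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            ...  # forward pass, loss, backward pass, optimizer step
        model.eval()
        for batch in aiia_loader.val_loader:
            ...  # validation pass without gradient tracking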