updated loader

Falko Victor Habel 2025-01-26 21:58:13 +01:00
parent a8cd9b00e5
commit 7a1eb8bd30
2 changed files with 36 additions and 80 deletions

View File

@@ -58,16 +58,16 @@ class JPGImageLoader:
     def _get_image(self, item):
         try:
             # Retrieve the string data
             data = item[self.bytes_column]
-            # Check if the data is a string, and decode it
-            if isinstance(data, str):
-                bytes_data = base64.b64decode(data)  # Adjust decoding as per your data encoding format
+            if isinstance(data, str) and data.startswith("b'"):
+                cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
+                bytes_data = cleaned_data
+            elif isinstance(data, str):
+                bytes_data = base64.b64decode(data)
             else:
                 bytes_data = data
-            # Load the bytes into a BytesIO object and open the image
             img_bytes = io.BytesIO(bytes_data)
             image = Image.open(img_bytes).convert("RGB")
             return image
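Why the new branch exists: bytes columns that pass through str() (for example a CSV round-trip) arrive as the textual repr "b'...'" rather than real bytes, so base64-decoding them would fail. A minimal sketch of the recovery the branch performs; the byte values below are stand-ins, not data from this repo:

    raw = bytes(range(256))                   # stand-in for real JPEG bytes
    stringified = str(raw)                    # "b'\x00\x01...'" -- the shape the loader now detects
    assert stringified.startswith("b'")
    restored = stringified[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
    assert restored == raw                    # round-trip recovers the original bytes intact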
@@ -75,7 +75,6 @@ class JPGImageLoader:
             print(f"Error loading image from bytes: {e}")
             return None
-
     def get_item(self, idx):
         item = self.dataset.iloc[idx]
         image = self._get_image(item)
@@ -94,92 +93,59 @@ class JPGImageLoader:
         print(f"Successfully converted {self.successful_count} images.")
         print(f"Skipped {self.skipped_count} images due to errors.")
 
 class AIIADataLoader:
-    def __init__(self, dataset,
-                batch_size=32,
-                val_split=0.2,
-                seed=42,
-                column="file_path",
-                label_column=None,
-                **dataloader_kwargs):
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, **dataloader_kwargs):
         self.batch_size = batch_size
         self.val_split = val_split
         self.seed = seed
         random.seed(seed)
 
-        is_bytes_or_bytestring = any(
-            isinstance(value, (bytes, memoryview))
-            for value in dataset[column].dropna().head(1).astype(str)
+        sample_value = dataset[column].iloc[0]
+        is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
+            isinstance(sample_value, bytes) or
+            sample_value.startswith("b'") or
+            sample_value.startswith(('b"', 'data:image'))
         )
 
         if is_bytes_or_bytestring:
-            self.loader = JPGImageLoader(
-                dataset,
-                bytes_column=column,
-                label_column=label_column
-            )
+            self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
         else:
             sample_paths = dataset[column].dropna().head(1).astype(str)
-            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'
+            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|png|gif)$'
 
-            if any(
-                re.match(filepath_pattern, path, flags=re.IGNORECASE)
-                for path in sample_paths
-            ):
-                self.loader = FilePathLoader(
-                    dataset,
-                    file_path_column=column,
-                    label_column=label_column
-                )
+            if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
+                self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
             else:
-                self.loader = JPGImageLoader(
-                    dataset,
-                    bytes_column=column,
-                    label_column=label_column
-                )
+                self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
 
-        self.items = [self.loader.get_item(idx) for idx in range(len(dataset))]
+        self.items = []
+        for idx in range(len(dataset)):
+            item = self.loader.get_item(idx)
+            if item is not None:
+                self.items.append(item)
 
         if not self.items:
             raise ValueError("No valid items were loaded from the dataset")
 
         train_indices, val_indices = self._split_data()
         self.train_dataset = self._create_subset(train_indices)
         self.val_dataset = self._create_subset(val_indices)
 
-        self.train_loader = DataLoader(
-            self.train_dataset,
-            batch_size=batch_size,
-            shuffle=True,
-            **dataloader_kwargs
-        )
-        self.val_loader = DataLoader(
-            self.val_dataset,
-            batch_size=batch_size,
-            shuffle=False,
-            **dataloader_kwargs
-        )
+        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
+        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)
 
     def _split_data(self):
         if len(self.items) == 0:
-            return [], []
+            raise ValueError("No items to split")
 
-        tasks = [item[1] for item in self.items if len(item) > 1 and hasattr(item, '__getitem__') and item[1] is not None]
         num_samples = len(self.items)
         indices = list(range(num_samples))
         random.shuffle(indices)
 
-        unique_tasks = list(set(tasks)) if tasks.count(None) < len(tasks) else []
-        train_indices = []
-        val_indices = []
-        for task in unique_tasks:
-            task_indices = [i for i, t in enumerate(tasks) if t == task]
-            n_val = int(len(task_indices) * self.val_split)
-            random.shuffle(task_indices)
-            val_indices.extend(task_indices[:n_val])
-            train_indices.extend(task_indices[n_val:])
+        split_idx = int((1 - self.val_split) * num_samples)
+        train_indices = indices[:split_idx]
+        val_indices = indices[split_idx:]
 
         return train_indices, val_indices
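Note the split change at the end of the class: the per-task stratified split is replaced by a single seeded shuffle. A short sketch of the new behavior with illustrative numbers (the item count is made up):

    import random

    random.seed(42)                              # mirrors random.seed(seed) in __init__
    indices = list(range(10))                    # stand-ins for 10 loaded items
    random.shuffle(indices)

    val_split = 0.2
    split_idx = int((1 - val_split) * len(indices))
    train_indices, val_indices = indices[:split_idx], indices[split_idx:]
    print(len(train_indices), len(val_indices))  # -> 8 2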

View File

@@ -18,24 +18,14 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
     # Create a new AIIAConfig instance
     config = AIIAConfig(
-        model_name="AIIA-512x",
-        hidden_size=512,
-        num_hidden_layers=12,
-        kernel_size=5,
-        learning_rate=5e-5
+        model_name="AIIA-Base-512x20k",
     )
 
     # Initialize the base model
     model = AIIABase(config)
 
     # Create dataset loader with merged data
-    aiia_loader = AIIADataLoader(
-        merged_df,
-        batch_size=32,
-        val_split=0.2,
-        seed=42,
-        column="image_bytes"
-    )
+    aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32)
 
     # Access the train and validation loaders
     train_loader = aiia_loader.train_loader
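For orientation, the trimmed call site leans on the constructor defaults (val_split=0.2, seed=42). A hedged sketch of how the loaders are then consumed; merged_df and the AIIA imports come from the surrounding script and are assumed here:

    aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32)

    for batch in aiia_loader.train_loader:   # shuffled 80% split
        pass                                 # forward/backward pass per step
    for batch in aiia_loader.val_loader:     # fixed-order 20% split
        pass                                 # evaluation only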