updated loader
parent a8cd9b00e5
commit 7a1eb8bd30
@@ -58,16 +58,16 @@ class JPGImageLoader:
     def _get_image(self, item):
         try:
+            # Retrieve the string data
             data = item[self.bytes_column]
 
+            # Check if the data is a string, and decode it
-            if isinstance(data, str):
-                bytes_data = base64.b64decode(data)  # Adjust decoding as per your data encoding format
+            if isinstance(data, str) and data.startswith("b'"):
+                cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
+                bytes_data = cleaned_data
+            elif isinstance(data, str):
+                bytes_data = base64.b64decode(data)
             else:
                 bytes_data = data
 
+            # Load the bytes into a BytesIO object and open the image
             img_bytes = io.BytesIO(bytes_data)
             image = Image.open(img_bytes).convert("RGB")
             return image
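Note on the decode path above: the latin1 / unicode-escape / latin1 round trip recovers raw bytes from a column that stores the textual repr of a Python bytes literal (e.g. the string "b'\xff\xd8'"). A minimal sketch, using a hypothetical two-byte value:

    # Hypothetical stored value: the repr of b'\xff\xd8' saved as text.
    data = "b'\\xff\\xd8'"
    inner = data[2:-1]                 # strip the leading b' and the trailing '
    raw = inner.encode('latin1').decode('unicode-escape').encode('latin1')
    assert raw == b'\xff\xd8'          # back to the original JPEG SOI marker bytes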
@@ -75,7 +75,6 @@ class JPGImageLoader:
             print(f"Error loading image from bytes: {e}")
             return None
 
     def get_item(self, idx):
         item = self.dataset.iloc[idx]
         image = self._get_image(item)
@@ -94,92 +93,59 @@ class JPGImageLoader:
         print(f"Successfully converted {self.successful_count} images.")
         print(f"Skipped {self.skipped_count} images due to errors.")
 
 
 class AIIADataLoader:
-    def __init__(self, dataset,
-                 batch_size=32,
-                 val_split=0.2,
-                 seed=42,
-                 column="file_path",
-                 label_column=None,
-                 **dataloader_kwargs):
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, **dataloader_kwargs):
         self.batch_size = batch_size
         self.val_split = val_split
         self.seed = seed
         random.seed(seed)
 
-        is_bytes_or_bytestring = any(
-            isinstance(value, (bytes, memoryview))
-            for value in dataset[column].dropna().head(1).astype(str)
-        )
+        sample_value = dataset[column].iloc[0]
+        is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
+            isinstance(sample_value, bytes) or
+            sample_value.startswith("b'") or
+            sample_value.startswith(('b"', 'data:image'))
+        )
 
         if is_bytes_or_bytestring:
-            self.loader = JPGImageLoader(
-                dataset,
-                bytes_column=column,
-                label_column=label_column
-            )
+            self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
         else:
             sample_paths = dataset[column].dropna().head(1).astype(str)
-            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'
+            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|png|gif)$'
 
-            if any(
-                re.match(filepath_pattern, path, flags=re.IGNORECASE)
-                for path in sample_paths
-            ):
-                self.loader = FilePathLoader(
-                    dataset,
-                    file_path_column=column,
-                    label_column=label_column
-                )
+            if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
+                self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
             else:
-                self.loader = JPGImageLoader(
-                    dataset,
-                    bytes_column=column,
-                    label_column=label_column
-                )
+                self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
 
-        self.items = [self.loader.get_item(idx) for idx in range(len(dataset))]
+        self.items = []
+        for idx in range(len(dataset)):
+            item = self.loader.get_item(idx)
+            if item is not None:
+                self.items.append(item)
 
         if not self.items:
             raise ValueError("No valid items were loaded from the dataset")
 
         train_indices, val_indices = self._split_data()
 
         self.train_dataset = self._create_subset(train_indices)
         self.val_dataset = self._create_subset(val_indices)
 
-        self.train_loader = DataLoader(
-            self.train_dataset,
-            batch_size=batch_size,
-            shuffle=True,
-            **dataloader_kwargs
-        )
-
-        self.val_loader = DataLoader(
-            self.val_dataset,
-            batch_size=batch_size,
-            shuffle=False,
-            **dataloader_kwargs
-        )
+        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
+        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)
 
     def _split_data(self):
         if len(self.items) == 0:
-            return [], []
+            raise ValueError("No items to split")
 
-        tasks = [item[1] for item in self.items if len(item) > 1 and hasattr(item, '__getitem__') and item[1] is not None]
         num_samples = len(self.items)
         indices = list(range(num_samples))
         random.shuffle(indices)
 
-        unique_tasks = list(set(tasks)) if tasks.count(None) < len(tasks) else []
-
-        train_indices = []
-        val_indices = []
-
-        for task in unique_tasks:
-            task_indices = [i for i, t in enumerate(tasks) if t == task]
-            n_val = int(len(task_indices) * self.val_split)
-
-            random.shuffle(task_indices)
-
-            val_indices.extend(task_indices[:n_val])
-            train_indices.extend(task_indices[n_val:])
+        split_idx = int((1 - self.val_split) * num_samples)
+        train_indices = indices[:split_idx]
+        val_indices = indices[split_idx:]
 
         return train_indices, val_indices
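The loader choice above keys off the first row of the column. A minimal sketch of that heuristic, with hypothetical sample values (only the isinstance/startswith checks mirror the committed logic):

    def looks_like_bytes(sample_value):
        # Raw bytes, stringified bytes literals, and data URIs all route
        # to JPGImageLoader; anything else falls through to the path check.
        return isinstance(sample_value, (bytes, str)) and (
            isinstance(sample_value, bytes)
            or sample_value.startswith("b'")
            or sample_value.startswith(('b"', 'data:image'))
        )

    print(looks_like_bytes(b'\xff\xd8'))           # True: raw bytes
    print(looks_like_bytes("b'\\xff\\xd8'"))       # True: repr of a bytes literal
    print(looks_like_bytes("images/cat_001.jpg"))  # False: handled as a file path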
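_split_data now does a single seeded shuffle-and-cut instead of the earlier per-task stratified split. A self-contained sketch of the new behaviour (the 100-item range is a stand-in for self.items):

    import random

    random.seed(42)                    # matches the loader's default seed
    num_samples = 100
    indices = list(range(num_samples))
    random.shuffle(indices)

    val_split = 0.2
    split_idx = int((1 - val_split) * num_samples)
    train_indices = indices[:split_idx]
    val_indices = indices[split_idx:]
    print(len(train_indices), len(val_indices))  # 80 20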
@@ -18,24 +18,14 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
     # Create a new AIIAConfig instance
     config = AIIAConfig(
-        model_name="AIIA-512x",
-        hidden_size=512,
-        num_hidden_layers=12,
-        kernel_size=5,
-        learning_rate=5e-5
+        model_name="AIIA-Base-512x20k",
     )
 
     # Initialize the base model
     model = AIIABase(config)
 
     # Create dataset loader with merged data
-    aiia_loader = AIIADataLoader(
-        merged_df,
-        batch_size=32,
-        val_split=0.2,
-        seed=42,
-        column="image_bytes"
-    )
+    aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32)
 
     # Access the train and validation loaders
     train_loader = aiia_loader.train_loader
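The trimmed AIIADataLoader call leans on the constructor defaults (val_split=0.2, seed=42), so it is equivalent to the removed explicit form. Hypothetical usage of the resulting loaders (merged_df is assumed to be a pandas DataFrame whose "image_bytes" column holds encoded images):

    aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32)

    for batch in aiia_loader.train_loader:
        pass  # training step goes here; validation batches come from aiia_loader.val_loader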