updated loader
This commit is contained in:
parent
a8cd9b00e5
commit
7a1eb8bd30
|
@ -58,23 +58,22 @@ class JPGImageLoader:
|
||||||
|
|
||||||
def _get_image(self, item):
|
def _get_image(self, item):
|
||||||
try:
|
try:
|
||||||
# Retrieve the string data
|
|
||||||
data = item[self.bytes_column]
|
data = item[self.bytes_column]
|
||||||
|
|
||||||
# Check if the data is a string, and decode it
|
if isinstance(data, str) and data.startswith("b'"):
|
||||||
if isinstance(data, str):
|
cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
|
||||||
bytes_data = base64.b64decode(data) # Adjust decoding as per your data encoding format
|
bytes_data = cleaned_data
|
||||||
|
elif isinstance(data, str):
|
||||||
|
bytes_data = base64.b64decode(data)
|
||||||
else:
|
else:
|
||||||
bytes_data = data
|
bytes_data = data
|
||||||
|
|
||||||
# Load the bytes into a BytesIO object and open the image
|
|
||||||
img_bytes = io.BytesIO(bytes_data)
|
img_bytes = io.BytesIO(bytes_data)
|
||||||
image = Image.open(img_bytes).convert("RGB")
|
image = Image.open(img_bytes).convert("RGB")
|
||||||
return image
|
return image
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading image from bytes: {e}")
|
print(f"Error loading image from bytes: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_item(self, idx):
|
def get_item(self, idx):
|
||||||
item = self.dataset.iloc[idx]
|
item = self.dataset.iloc[idx]
|
||||||
|
@ -93,94 +92,61 @@ class JPGImageLoader:
|
||||||
def print_summary(self):
|
def print_summary(self):
|
||||||
print(f"Successfully converted {self.successful_count} images.")
|
print(f"Successfully converted {self.successful_count} images.")
|
||||||
print(f"Skipped {self.skipped_count} images due to errors.")
|
print(f"Skipped {self.skipped_count} images due to errors.")
|
||||||
|
|
||||||
|
|
||||||
class AIIADataLoader:
|
class AIIADataLoader:
|
||||||
def __init__(self, dataset,
|
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, **dataloader_kwargs):
|
||||||
batch_size=32,
|
|
||||||
val_split=0.2,
|
|
||||||
seed=42,
|
|
||||||
column="file_path",
|
|
||||||
label_column=None,
|
|
||||||
**dataloader_kwargs):
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.val_split = val_split
|
self.val_split = val_split
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
is_bytes_or_bytestring = any(
|
sample_value = dataset[column].iloc[0]
|
||||||
isinstance(value, (bytes, memoryview))
|
is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
|
||||||
for value in dataset[column].dropna().head(1).astype(str)
|
isinstance(sample_value, bytes) or
|
||||||
|
sample_value.startswith("b'") or
|
||||||
|
sample_value.startswith(('b"', 'data:image'))
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_bytes_or_bytestring:
|
if is_bytes_or_bytestring:
|
||||||
self.loader = JPGImageLoader(
|
self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
|
||||||
dataset,
|
|
||||||
bytes_column=column,
|
|
||||||
label_column=label_column
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
sample_paths = dataset[column].dropna().head(1).astype(str)
|
sample_paths = dataset[column].dropna().head(1).astype(str)
|
||||||
|
filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'
|
||||||
|
|
||||||
filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|png|gif)$'
|
if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
|
||||||
|
self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
|
||||||
if any(
|
|
||||||
re.match(filepath_pattern, path, flags=re.IGNORECASE)
|
|
||||||
for path in sample_paths
|
|
||||||
):
|
|
||||||
self.loader = FilePathLoader(
|
|
||||||
dataset,
|
|
||||||
file_path_column=column,
|
|
||||||
label_column=label_column
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.loader = JPGImageLoader(
|
self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
|
||||||
dataset,
|
|
||||||
bytes_column=column,
|
|
||||||
label_column=label_column
|
|
||||||
)
|
|
||||||
|
|
||||||
self.items = [self.loader.get_item(idx) for idx in range(len(dataset))]
|
self.items = []
|
||||||
|
for idx in range(len(dataset)):
|
||||||
|
item = self.loader.get_item(idx)
|
||||||
|
if item is not None:
|
||||||
|
self.items.append(item)
|
||||||
|
|
||||||
|
if not self.items:
|
||||||
|
raise ValueError("No valid items were loaded from the dataset")
|
||||||
|
|
||||||
train_indices, val_indices = self._split_data()
|
train_indices, val_indices = self._split_data()
|
||||||
|
|
||||||
self.train_dataset = self._create_subset(train_indices)
|
self.train_dataset = self._create_subset(train_indices)
|
||||||
self.val_dataset = self._create_subset(val_indices)
|
self.val_dataset = self._create_subset(val_indices)
|
||||||
|
|
||||||
self.train_loader = DataLoader(
|
self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
|
||||||
self.train_dataset,
|
self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
**dataloader_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
self.val_loader = DataLoader(
|
|
||||||
self.val_dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
**dataloader_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def _split_data(self):
|
def _split_data(self):
|
||||||
if len(self.items) == 0:
|
if len(self.items) == 0:
|
||||||
return [], []
|
raise ValueError("No items to split")
|
||||||
|
|
||||||
tasks = [item[1] for item in self.items if len(item) > 1 and hasattr(item, '__getitem__') and item[1] is not None]
|
num_samples = len(self.items)
|
||||||
|
indices = list(range(num_samples))
|
||||||
unique_tasks = list(set(tasks)) if tasks.count(None) < len(tasks) else []
|
random.shuffle(indices)
|
||||||
|
|
||||||
train_indices = []
|
split_idx = int((1 - self.val_split) * num_samples)
|
||||||
val_indices = []
|
train_indices = indices[:split_idx]
|
||||||
|
val_indices = indices[split_idx:]
|
||||||
|
|
||||||
for task in unique_tasks:
|
|
||||||
task_indices = [i for i, t in enumerate(tasks) if t == task]
|
|
||||||
n_val = int(len(task_indices) * self.val_split)
|
|
||||||
|
|
||||||
random.shuffle(task_indices)
|
|
||||||
|
|
||||||
val_indices.extend(task_indices[:n_val])
|
|
||||||
train_indices.extend(task_indices[n_val:])
|
|
||||||
|
|
||||||
return train_indices, val_indices
|
return train_indices, val_indices
|
||||||
|
|
||||||
def _create_subset(self, indices):
|
def _create_subset(self, indices):
|
||||||
|
@ -218,4 +184,4 @@ class AIIADataset(torch.utils.data.Dataset):
|
||||||
if isinstance(item, Image.Image):
|
if isinstance(item, Image.Image):
|
||||||
return item
|
return item
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid item format.")
|
raise ValueError("Invalid item format.")
|
|
@ -18,24 +18,14 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
|
|
||||||
# Create a new AIIAConfig instance
|
# Create a new AIIAConfig instance
|
||||||
config = AIIAConfig(
|
config = AIIAConfig(
|
||||||
model_name="AIIA-512x",
|
model_name="AIIA-Base-512x20k",
|
||||||
hidden_size=512,
|
|
||||||
num_hidden_layers=12,
|
|
||||||
kernel_size=5,
|
|
||||||
learning_rate=5e-5
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the base model
|
# Initialize the base model
|
||||||
model = AIIABase(config)
|
model = AIIABase(config)
|
||||||
|
|
||||||
# Create dataset loader with merged data
|
# Create dataset loader with merged data
|
||||||
aiia_loader = AIIADataLoader(
|
aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32)
|
||||||
merged_df,
|
|
||||||
batch_size=32,
|
|
||||||
val_split=0.2,
|
|
||||||
seed=42,
|
|
||||||
column="image_bytes"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Access the train and validation loaders
|
# Access the train and validation loaders
|
||||||
train_loader = aiia_loader.train_loader
|
train_loader = aiia_loader.train_loader
|
||||||
|
|
Loading…
Reference in New Issue