finetune_class #1
@@ -8,24 +8,41 @@ import csv
 from tqdm import tqdm
 
 class UpscaleDataset(Dataset):
-    def __init__(self, parquet_file, transform=None):
-        self.df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(10000)
+    def __init__(self, parquet_files: list, transform=None):
+        # Initialize an empty DataFrame to hold the combined data
+        combined_df = pd.DataFrame()
+
+        # Iterate through each Parquet file in the list and load it into a DataFrame
+        for parquet_file in parquet_files:
+            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(10000)
+            combined_df = pd.concat([combined_df, df], ignore_index=True)
+
+        # Keep the combined frame on the instance so __len__ and __getitem__ can use it
+        self.df = combined_df
         self.transform = transform
 
     def __len__(self):
         return len(self.df)
 
     def __getitem__(self, idx):
-        row = self.df.iloc[idx]
-        # Decode the byte strings into images
-        low_res_bytes = row['image_512']
-        high_res_bytes = row['image_1024']
-        low_res_image = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
-        high_res_image = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
-        if self.transform:
-            low_res_image = self.transform(low_res_image)
-            high_res_image = self.transform(high_res_image)
-        return low_res_image, high_res_image
+        try:
+            row = self.df.iloc[idx]
+            # Convert string to bytes if necessary
+            low_res_bytes = row['image_512'].encode('latin-1') if isinstance(row['image_512'], str) else row['image_512']
+            high_res_bytes = row['image_1024'].encode('latin-1') if isinstance(row['image_1024'], str) else row['image_1024']
+
+            # Decode the bytes into images
+            low_res_image = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
+            high_res_image = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
+
+            if self.transform:
+                low_res_image = self.transform(low_res_image)
+                high_res_image = self.transform(high_res_image)
+            return low_res_image, high_res_image
+        except Exception as e:
+            print(f"Error processing index {idx}: {str(e)}")
+            # You might want to either skip this sample or return a default value
+            raise e
 
 # Example transform: converting PIL images to tensors
 transform = transforms.Compose([
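Review note on the encode('latin-1') fallback above: it recovers the original PNG/JPEG bytes only if the writer stored them as a string produced by bytes.decode('latin-1'), which is lossless because Latin-1 maps every byte value 0-255 to exactly one code point. A minimal sketch of that assumption (the raw value below is a stand-in, not data read from the parquet files):

# Latin-1 round trip: arbitrary bytes -> str -> the same bytes.
raw = bytes(range(256))                  # stand-in for encoded image bytes
as_str = raw.decode('latin-1')           # how raw bytes can end up stored as a str column
assert as_str.encode('latin-1') == raw   # lossless, so Image.open() sees the original data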
@@ -46,7 +63,7 @@ from torch import nn, optim
 from torch.utils.data import DataLoader
 
 # Create your dataset and dataloader
-dataset = UpscaleDataset("/root/training_data/vision-dataset/image_upscaler.parquet", transform=transform)
+dataset = UpscaleDataset(["/root/training_data/vision-dataset/image_upscaler.parquet", "/root/training_data/vision-dataset/image_vec_upscaler.parquet"], transform=transform)
 data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
 
 # Define a loss function and optimizer
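One possible follow-up to the "skip this sample or return a default value" comment in __getitem__: have the except branch return None instead of re-raising, and drop those entries in a custom collate_fn. This is a sketch, not part of the diff; skip_broken_collate is a hypothetical helper.

from torch.utils.data import default_collate

def skip_broken_collate(batch):
    # Drop samples that failed to decode (returned as None by __getitem__).
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        raise RuntimeError("every sample in this batch failed to decode")
    return default_collate(batch)

# data_loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=skip_broken_collate)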