finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed file with 28 additions and 13 deletions
Showing only changes of commit 2f12fcb863 - Show all commits

View File

@ -8,24 +8,39 @@ import csv
from tqdm import tqdm from tqdm import tqdm
class UpscaleDataset(Dataset):
    """Dataset of paired images loaded from one or more Parquet files.

    Each row supplies two encoded images, read from the ``image_512`` and
    ``image_1024`` columns, which are decoded to RGB PIL images on access.
    (Presumably these are the low- and high-resolution versions of the same
    picture -- confirm against the data producer.)

    Args:
        parquet_files: List of Parquet file paths. A single path given as a
            plain string is also accepted and treated as a one-element list.
        transform: Optional callable applied to both decoded images.
        max_rows_per_file: Maximum number of leading rows kept from each
            file; ``None`` keeps every row. Defaults to 10000.
    """

    _COLUMNS = ['image_512', 'image_1024']

    def __init__(self, parquet_files: list, transform=None, *, max_rows_per_file=10000):
        # Accept the older single-path call style as well as a list of paths.
        if isinstance(parquet_files, str):
            parquet_files = [parquet_files]

        # Collect the per-file frames and concatenate once: concatenating
        # inside the loop re-copies the accumulated data for every file.
        frames = []
        for parquet_file in parquet_files:
            frame = pd.read_parquet(parquet_file, columns=self._COLUMNS)
            if max_rows_per_file is not None:
                frame = frame.head(max_rows_per_file)
            frames.append(frame)

        # The combined frame must be stored on the instance; __len__ and
        # __getitem__ both read self.df.
        if frames:
            self.df = pd.concat(frames, ignore_index=True)
        else:
            self.df = pd.DataFrame(columns=self._COLUMNS)
        self.transform = transform

    def __len__(self):
        """Return the number of image pairs available."""
        return len(self.df)

    @staticmethod
    def _decode_image(raw):
        """Decode one stored image payload (bytes or latin-1 str) to RGB."""
        # Some rows hold the payload as a str; latin-1 maps each character
        # back to a single byte.
        if isinstance(raw, str):
            raw = raw.encode('latin-1')
        return Image.open(io.BytesIO(raw)).convert('RGB')

    def __getitem__(self, idx):
        """Return the ``(image_512, image_1024)`` pair at position ``idx``.

        Both images pass through ``self.transform`` when one is set. Any
        failure is reported on stdout with the offending index and then
        re-raised unchanged.
        """
        try:
            row = self.df.iloc[idx]
            low_res_image = self._decode_image(row['image_512'])
            high_res_image = self._decode_image(row['image_1024'])

            if self.transform:
                low_res_image = self.transform(low_res_image)
                high_res_image = self.transform(high_res_image)
            return low_res_image, high_res_image
        except Exception as e:
            print(f"Error processing index {idx}: {str(e)}")
            # Bare raise keeps the original traceback intact.
            raise
# Example transform: converting PIL images to tensors # Example transform: converting PIL images to tensors
transform = transforms.Compose([ transform = transforms.Compose([
@ -46,7 +61,7 @@ from torch import nn, optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
# Build the dataset from both Parquet sources and wrap it in a shuffling,
# batched loader.
dataset = UpscaleDataset(
    parquet_files=[
        "/root/training_data/vision-dataset/image_upscaler.parquet",
        "/root/training_data/vision-dataset/image_vec_upscaler.parquet",
    ],
    transform=transform,
)
data_loader = DataLoader(dataset=dataset, shuffle=True, batch_size=16)
# Define a loss function and optimizer # Define a loss function and optimizer