finetune_class #1
@@ -40,9 +40,11 @@ class UpscaleDataset(Dataset):
     def __init__(self, parquet_files: list, transform=None):
         combined_df = pd.DataFrame()
         for parquet_file in parquet_files:
+            # Load a subset (head(2500)) from each parquet file
             df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(2500)
             combined_df = pd.concat([combined_df, df], ignore_index=True)
 
+        # Validate rows (ensuring each value is bytes or str)
         self.df = combined_df.apply(self._validate_row, axis=1)
         self.transform = transform
         self.failed_indices = set()
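Review note: pd.concat inside the loop re-copies the accumulated frame on every iteration, which gets quadratic as the file list grows. A minimal sketch of the usual single-concat idiom, using the same columns and head(2500) subset as the hunk above (load_combined is a hypothetical helper name):

    import pandas as pd

    def load_combined(parquet_files: list) -> pd.DataFrame:
        # Read the same 2500-row subset from each file, then concatenate once.
        frames = [
            pd.read_parquet(f, columns=['image_512', 'image_1024']).head(2500)
            for f in parquet_files
        ]
        return pd.concat(frames, ignore_index=True)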
@@ -67,6 +69,7 @@ class UpscaleDataset(Dataset):
         return len(self.df)
 
     def __getitem__(self, idx):
+        # If a previous call failed for this index, use a different index.
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]
        try:
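Review note: the wrap-around retry above recurses once per skipped index, so if many rows end up in failed_indices a single lookup can hit Python's recursion limit. A minimal iterative sketch of the same wrap-around policy with a stop condition (next_valid_index is a hypothetical standalone helper, not part of this PR):

    def next_valid_index(idx: int, length: int, failed: set) -> int:
        # Walk forward with wrap-around until we hit an index that has not failed.
        for offset in range(length):
            candidate = (idx + offset) % length
            if candidate not in failed:
                return candidate
        raise RuntimeError("all dataset indices have failed to load")

    # Usage: next_valid_index(3, 4, failed={3, 0}) -> 1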
@@ -74,10 +77,17 @@ class UpscaleDataset(Dataset):
             low_res_bytes = self._decode_image(row['image_512'])
             high_res_bytes = self._decode_image(row['image_1024'])
             ImageFile.LOAD_TRUNCATED_IMAGES = True
-            low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
-            high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
-            if low_res.size != (512, 512) or high_res.size != (1024, 1024):
-                raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
+
+            # Open image bytes with Pillow and convert to RGBA
+            low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
+            high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')
+
+            # Resize the images to reduce VRAM usage.
+            # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter.
+            low_res = low_res.resize((384, 384), Image.LANCZOS)
+            high_res = high_res.resize((768, 768), Image.LANCZOS)
+
+            # If a transform is provided (e.g. conversion to Tensor), apply it.
             if self.transform:
                 low_res = self.transform(low_res)
                 high_res = self.transform(high_res)
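Review note: Image.ANTIALIAS was deprecated in Pillow 9.1 and removed in 10.0, so the hunk above uses Image.LANCZOS, which is the same filter under its current name. A self-contained sketch of the decode-and-resize step in isolation (decode_and_resize is a hypothetical helper; byte strings stand in for the parquet columns, sizes match the PR):

    import io
    from PIL import Image, ImageFile

    ImageFile.LOAD_TRUNCATED_IMAGES = True  # tolerate truncated files, as in the PR

    def decode_and_resize(image_bytes: bytes, size: tuple) -> Image.Image:
        # Decode raw bytes, normalize to RGBA, and downscale with Lanczos.
        img = Image.open(io.BytesIO(image_bytes)).convert('RGBA')
        return img.resize(size, Image.LANCZOS)

    # Usage mirroring the diff: 512 -> 384 for low-res, 1024 -> 768 for high-res.
    # low_res = decode_and_resize(row['image_512'], (384, 384))
    # high_res = decode_and_resize(row['image_1024'], (768, 768))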
@@ -97,7 +107,7 @@ pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
 base_model = AIIABase.load(pretrained_model_path, precision="bf16")
 model = Upsampler(base_model)
 
-device = torch.device("cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 # Move model to device using channels_last memory format.
 model = model.to(device, memory_format=torch.channels_last)
 
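Review note: converting only the model to channels_last gives little benefit unless the input batches use the same memory format, since the layout propagates from the inputs. A minimal sketch of the matching input-side conversion (the batch shape is illustrative; 4 channels because the dataset now yields RGBA):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Input batches must also be converted to channels_last to match the model.
    batch = torch.randn(8, 4, 384, 384)  # NCHW: RGBA gives 4 channels after ToTensor
    batch = batch.to(device, memory_format=torch.channels_last)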