downsized imageset
This commit is contained in:
parent
443b9f5589
commit
2360e23cc7
|
@ -40,9 +40,11 @@ class UpscaleDataset(Dataset):
|
|||
def __init__(self, parquet_files: list, transform=None):
|
||||
combined_df = pd.DataFrame()
|
||||
for parquet_file in parquet_files:
|
||||
# Load a subset (head(2500)) from each parquet file
|
||||
df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(2500)
|
||||
combined_df = pd.concat([combined_df, df], ignore_index=True)
|
||||
|
||||
|
||||
# Validate rows (ensuring each value is bytes or str)
|
||||
self.df = combined_df.apply(self._validate_row, axis=1)
|
||||
self.transform = transform
|
||||
self.failed_indices = set()
|
||||
|
@ -67,6 +69,7 @@ class UpscaleDataset(Dataset):
|
|||
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# If previous call failed for this index, use a different index.
|
||||
if idx in self.failed_indices:
|
||||
return self[(idx + 1) % len(self)]
|
||||
try:
|
||||
|
@ -74,10 +77,17 @@ class UpscaleDataset(Dataset):
|
|||
low_res_bytes = self._decode_image(row['image_512'])
|
||||
high_res_bytes = self._decode_image(row['image_1024'])
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
|
||||
high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
|
||||
if low_res.size != (512, 512) or high_res.size != (1024, 1024):
|
||||
raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
|
||||
|
||||
# Open image bytes with Pillow and convert to RGBA
|
||||
low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
|
||||
high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')
|
||||
|
||||
# Resize the images to reduce VRAM usage.
|
||||
# Using Image.ANTIALIAS which is equivalent to LANCZOS in current Pillow versions.
|
||||
low_res = low_res.resize((384, 384), Image.ANTIALIAS)
|
||||
high_res = high_res.resize((768, 768), Image.ANTIALIAS)
|
||||
|
||||
# If a transform is provided (e.g. conversion to Tensor), apply it.
|
||||
if self.transform:
|
||||
low_res = self.transform(low_res)
|
||||
high_res = self.transform(high_res)
|
||||
|
@ -97,7 +107,7 @@ pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
|
|||
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
|
||||
model = Upsampler(base_model)
|
||||
|
||||
device = torch.device("cpu")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# Move model to device using channels_last memory format.
|
||||
model = model.to(device, memory_format=torch.channels_last)
|
||||
|
||||
|
|
Loading…
Reference in New Issue