From 2360e23cc776084521009b9161a7f46bf8443166 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Mon, 24 Feb 2025 16:20:45 +0100
Subject: [PATCH] downsized imageset

---
 src/aiunn/finetune.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 59e3bf9..64a4ac1 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -40,9 +42,11 @@ class UpscaleDataset(Dataset):
     def __init__(self, parquet_files: list, transform=None):
         combined_df = pd.DataFrame()
         for parquet_file in parquet_files:
+            # Load a subset (head(2500)) from each parquet file
             df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(2500)
             combined_df = pd.concat([combined_df, df], ignore_index=True)
-
+
+        # Validate rows (ensuring each value is bytes or str)
         self.df = combined_df.apply(self._validate_row, axis=1)
         self.transform = transform
         self.failed_indices = set()
@@ -67,6 +69,7 @@ class UpscaleDataset(Dataset):
         return len(self.df)
 
     def __getitem__(self, idx):
+        # If previous call failed for this index, use a different index.
         if idx in self.failed_indices:
             return self[(idx + 1) % len(self)]
         try:
@@ -74,10 +77,17 @@
             low_res_bytes = self._decode_image(row['image_512'])
             high_res_bytes = self._decode_image(row['image_1024'])
             ImageFile.LOAD_TRUNCATED_IMAGES = True
-            low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
-            high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
-            if low_res.size != (512, 512) or high_res.size != (1024, 1024):
-                raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
+
+            # Open image bytes with Pillow and convert to RGBA
+            low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
+            high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')
+
+            # Resize the images to reduce VRAM usage.
+            # Use Image.LANCZOS (Image.ANTIALIAS was deprecated and removed in Pillow 10).
+            low_res = low_res.resize((384, 384), Image.LANCZOS)
+            high_res = high_res.resize((768, 768), Image.LANCZOS)
+
+            # If a transform is provided (e.g. conversion to Tensor), apply it.
             if self.transform:
                 low_res = self.transform(low_res)
                 high_res = self.transform(high_res)
@@ -97,7 +107,7 @@
 pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
 base_model = AIIABase.load(pretrained_model_path, precision="bf16")
 model = Upsampler(base_model)
-device = torch.device("cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Move model to device using channels_last memory format.
 model = model.to(device, memory_format=torch.channels_last)