From ecb2694415b0a3bc33bd7bca200e37606aa57a6b Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Mon, 24 Feb 2025 16:28:18 +0100 Subject: [PATCH] update color channels --- src/aiunn/finetune.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index e03960b..03efb8f 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -77,16 +77,25 @@ class UpscaleDataset(Dataset): low_res_bytes = self._decode_image(row['image_512']) high_res_bytes = self._decode_image(row['image_1024']) ImageFile.LOAD_TRUNCATED_IMAGES = True - - # Open image bytes with Pillow and convert to RGBA - low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA') - high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA') - + # Open image bytes with Pillow and convert to RGBA first + low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA') + high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA') + + # Create a new RGB image with black background + low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0)) + high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0)) + + # Composite the original image over the black background + low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3]) + high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3]) + + # Now we have true 3-channel RGB images with transparent areas converted to black + low_res = low_res_rgb + high_res = high_res_rgb + # Resize the images to reduce VRAM usage. - # Using Image.ANTIALIAS which is equivalent to LANCZOS in current Pillow versions. low_res = low_res.resize((384, 384), Image.LANCZOS) high_res = high_res.resize((768, 768), Image.LANCZOS) - # If a transform is provided (e.g. conversion to Tensor), apply it. if self.transform: low_res = self.transform(low_res) @@ -96,7 +105,7 @@ class UpscaleDataset(Dataset): print(f"\nError at index {idx}: {str(e)}") self.failed_indices.add(idx) return self[(idx + 1) % len(self)] - + # Define any transformations you require. transform = transforms.Compose([ transforms.ToTensor(),