finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed file with 17 additions and 8 deletions
Showing only changes of commit ecb2694415 - Show all commits

View File

@@ -77,16 +77,25 @@ class UpscaleDataset(Dataset):
low_res_bytes = self._decode_image(row['image_512']) low_res_bytes = self._decode_image(row['image_512'])
high_res_bytes = self._decode_image(row['image_1024']) high_res_bytes = self._decode_image(row['image_1024'])
ImageFile.LOAD_TRUNCATED_IMAGES = True ImageFile.LOAD_TRUNCATED_IMAGES = True
# Open image bytes with Pillow and convert to RGBA first
# Open image bytes with Pillow and convert to RGBA low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA') high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')
high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')
# Create a new RGB image with black background
low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0))
high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0))
# Composite the original image over the black background
low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3])
high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3])
# Now we have true 3-channel RGB images with transparent areas converted to black
low_res = low_res_rgb
high_res = high_res_rgb
# Resize the images to reduce VRAM usage. # Resize the images to reduce VRAM usage.
# Using Image.LANCZOS; the older alias Image.ANTIALIAS named the same filter but was removed in Pillow 10.
low_res = low_res.resize((384, 384), Image.LANCZOS) low_res = low_res.resize((384, 384), Image.LANCZOS)
high_res = high_res.resize((768, 768), Image.LANCZOS) high_res = high_res.resize((768, 768), Image.LANCZOS)
# If a transform is provided (e.g. conversion to Tensor), apply it. # If a transform is provided (e.g. conversion to Tensor), apply it.
if self.transform: if self.transform:
low_res = self.transform(low_res) low_res = self.transform(low_res)
@@ -96,7 +105,7 @@ class UpscaleDataset(Dataset):
print(f"\nError at index {idx}: {str(e)}") print(f"\nError at index {idx}: {str(e)}")
self.failed_indices.add(idx) self.failed_indices.add(idx)
return self[(idx + 1) % len(self)] return self[(idx + 1) % len(self)]
# Define any transformations you require. # Define any transformations you require.
transform = transforms.Compose([ transform = transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),