finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 43 additions and 81 deletions
Showing only changes of commit 8fafbebe45 - Show all commits

View File

@ -8,23 +8,22 @@ from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageFile from PIL import Image, ImageFile
import io import io
import base64 import base64
import numpy as np
from torch import nn from torch import nn
# Import the model and config from your existing code from torch.utils.data import random_split
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
class aiuNNDataset(torch.utils.data.Dataset): class aiuNNDataset(torch.utils.data.Dataset):
def __init__(self, parquet_path): def __init__(self, parquet_path):
# Read the Parquet file self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024'])
self.df = pd.read_parquet(parquet_path).head(1250)
# Data augmentation pipeline without Resize as it's redundant
self.augmentation = Compose([ self.augmentation = Compose([
RandomBrightnessContrast(), RandomBrightnessContrast(p=0.5),
HorizontalFlip(p=0.5), HorizontalFlip(p=0.5),
VerticalFlip(p=0.5), VerticalFlip(p=0.5),
Rotate(degrees=45), Rotate(limit=45, p=0.5),
GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), GaussianBlur(blur_limit=(3, 7), p=0.5),
Normalize(mean=[0.5], std=[0.5]), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
ToTensorV2() ToTensorV2()
]) ])
@ -33,92 +32,66 @@ class aiuNNDataset(torch.utils.data.Dataset):
def load_image(self, image_data): def load_image(self, image_data):
try: try:
# Handle both bytes and base64 encoded strings
if isinstance(image_data, str): if isinstance(image_data, str):
# Decode base64 string to bytes
image_data = base64.b64decode(image_data) image_data = base64.b64decode(image_data)
# Verify data is valid before creating BytesIO
if not isinstance(image_data, bytes): if not isinstance(image_data, bytes):
raise ValueError("Invalid image data format") raise ValueError("Invalid image data format")
# Create image stream
image_stream = io.BytesIO(image_data) image_stream = io.BytesIO(image_data)
# Enable loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True ImageFile.LOAD_TRUNCATED_IMAGES = True
# Load and convert image to RGB
image = Image.open(image_stream).convert('RGB') image = Image.open(image_stream).convert('RGB')
image_array = np.array(image)
# Create fresh copy for verify() since it modifies the image object return image_array
image_verify = image.copy()
# Verify image is valid
try:
image_verify.verify()
except Exception as e:
raise ValueError(f"Image verification failed: {str(e)}")
finally:
image_verify.close()
return image
except Exception as e: except Exception as e:
raise RuntimeError(f"Error loading image: {str(e)}") raise RuntimeError(f"Error loading image: {str(e)}")
finally: finally:
# Ensure stream is closed
if 'image_stream' in locals(): if 'image_stream' in locals():
image_stream.close() image_stream.close()
def __getitem__(self, idx): def __getitem__(self, idx):
row = self.df.iloc[idx] row = self.df.iloc[idx]
# Load images using the new method
low_res_image = self.load_image(row['image_512']) low_res_image = self.load_image(row['image_512'])
high_res_image = self.load_image(row['image_1024']) high_res_image = self.load_image(row['image_1024'])
# Apply augmentation and normalization
augmented_low = self.augmentation(image=low_res_image) augmented_low = self.augmentation(image=low_res_image)
low_res = augmented_low['image']
augmented_high = self.augmentation(image=high_res_image) augmented_high = self.augmentation(image=high_res_image)
high_res = augmented_high['image']
return { return {
'low_res': low_res, 'low_res': augmented_low['image'],
'high_res': high_res 'high_res': augmented_high['image']
} }
from torch.utils.data.dataset import ConcatDataset def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10):
# Load all datasets and concatenate them
loaded_datasets = [aiuNNDataset(d) for d in datasets] loaded_datasets = [aiuNNDataset(d) for d in datasets]
combined_dataset = ConcatDataset(loaded_datasets) combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
# Split into training and validation sets train_size = int(0.8 * len(combined_dataset))
train_dataset, val_dataset = combined_dataset.train_val_split() val_size = len(combined_dataset) - train_size
train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, train_dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
num_workers=4 num_workers=4,
pin_memory=True,
persistent_workers=True
) )
val_loader = torch.utils.data.DataLoader( val_loader = torch.utils.data.DataLoader(
val_dataset, val_dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=False,
num_workers=4 num_workers=4,
pin_memory=True,
persistent_workers=True
) )
# Set device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cuda' if torch.cuda.is_available() else 'cpu' model = model.to(device)
model.to(device)
# Define loss function and optimizer
criterion = nn.MSELoss() criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate) optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
@ -128,73 +101,62 @@ def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10):
for epoch in range(epochs): for epoch in range(epochs):
model.train() model.train()
train_loss = 0.0 train_loss = 0.0
for batch_idx, batch in enumerate(tqdm(train_loader)): for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"):
# Your training code here if torch.cuda.is_available():
torch.cuda.empty_cache()
low_res = batch['low_res'].to(device) low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device) high_res = batch['high_res'].to(device)
# Forward pass
outputs = model(low_res)
# Calculate loss
loss = criterion(outputs, high_res.permute(0, 3, 1, 2)) # Adjust for channel dimensions
# Backward pass and optimize
optimizer.zero_grad() optimizer.zero_grad()
outputs = model(low_res)
loss = criterion(outputs, high_res)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
train_loss += loss.item() train_loss += loss.item()
avg_train_loss = train_loss / len(train_loader) avg_train_loss = train_loss / len(train_loader)
print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}") print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
# Validation
model.eval() model.eval()
val_loss = 0.0 val_loss = 0.0
with torch.no_grad(): with torch.no_grad():
for batch in tqdm(val_loader, desc="Validation"): for batch in tqdm(val_loader, desc="Validation"):
if torch.cuda.is_available():
torch.cuda.empty_cache()
low_res = batch['low_res'].to(device) low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device) high_res = batch['high_res'].to(device)
outputs = model(low_res) outputs = model(low_res)
loss = criterion(outputs, high_res.permute(0, 3, 1, 2)) loss = criterion(outputs, high_res)
val_loss += loss.item() val_loss += loss.item()
avg_val_loss = val_loss / len(val_loader) avg_val_loss = val_loss / len(val_loader)
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}") print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
# Save best model
if avg_val_loss < best_val_loss: if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss best_val_loss = avg_val_loss
model.save("best_model") torch.save(model.state_dict(), "best_model.pth")
return model return model
def main(): def main():
# Paths to your data BATCH_SIZE = 1
train_parquet_path = "/root/training_data/vision-dataset/image_upscaler.parquet"
val_parquet_path = "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
# Load pretrained model
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512") model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
# Add final upsampling layer if needed (depending on your specific architecture)
if hasattr(model, 'chunked_'): if hasattr(model, 'chunked_'):
model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear')) model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
# Fine-tune
finetune_model( finetune_model(
model, model=model,
train_parquet_path, datasets=[
val_parquet_path "/root/training_data/vision-dataset/image_upscaler.parquet",
"/root/training_data/vision-dataset/image_vec_upscaler.parquet"
],
batch_size=BATCH_SIZE
) )
if __name__ == '__main__': if __name__ == '__main__':