finetune_class #1
|
@ -8,23 +8,22 @@ from albumentations.pytorch import ToTensorV2
|
||||||
from PIL import Image, ImageFile
|
from PIL import Image, ImageFile
|
||||||
import io
|
import io
|
||||||
import base64
|
import base64
|
||||||
|
import numpy as np
|
||||||
from torch import nn
|
from torch import nn
|
||||||
# Import the model and config from your existing code
|
from torch.utils.data import random_split
|
||||||
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
|
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
|
||||||
|
|
||||||
class aiuNNDataset(torch.utils.data.Dataset):
|
class aiuNNDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, parquet_path):
|
def __init__(self, parquet_path):
|
||||||
# Read the Parquet file
|
self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024'])
|
||||||
self.df = pd.read_parquet(parquet_path).head(1250)
|
|
||||||
|
|
||||||
# Data augmentation pipeline without Resize as it's redundant
|
|
||||||
self.augmentation = Compose([
|
self.augmentation = Compose([
|
||||||
RandomBrightnessContrast(),
|
RandomBrightnessContrast(p=0.5),
|
||||||
HorizontalFlip(p=0.5),
|
HorizontalFlip(p=0.5),
|
||||||
VerticalFlip(p=0.5),
|
VerticalFlip(p=0.5),
|
||||||
Rotate(degrees=45),
|
Rotate(limit=45, p=0.5),
|
||||||
GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
|
GaussianBlur(blur_limit=(3, 7), p=0.5),
|
||||||
Normalize(mean=[0.5], std=[0.5]),
|
Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||||
ToTensorV2()
|
ToTensorV2()
|
||||||
])
|
])
|
||||||
|
|
||||||
|
@ -33,92 +32,66 @@ class aiuNNDataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
def load_image(self, image_data):
|
def load_image(self, image_data):
|
||||||
try:
|
try:
|
||||||
# Handle both bytes and base64 encoded strings
|
|
||||||
if isinstance(image_data, str):
|
if isinstance(image_data, str):
|
||||||
# Decode base64 string to bytes
|
|
||||||
image_data = base64.b64decode(image_data)
|
image_data = base64.b64decode(image_data)
|
||||||
|
|
||||||
# Verify data is valid before creating BytesIO
|
|
||||||
if not isinstance(image_data, bytes):
|
if not isinstance(image_data, bytes):
|
||||||
raise ValueError("Invalid image data format")
|
raise ValueError("Invalid image data format")
|
||||||
|
|
||||||
# Create image stream
|
|
||||||
image_stream = io.BytesIO(image_data)
|
image_stream = io.BytesIO(image_data)
|
||||||
|
|
||||||
# Enable loading of truncated images
|
|
||||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
|
||||||
# Load and convert image to RGB
|
|
||||||
image = Image.open(image_stream).convert('RGB')
|
image = Image.open(image_stream).convert('RGB')
|
||||||
|
image_array = np.array(image)
|
||||||
|
|
||||||
# Create fresh copy for verify() since it modifies the image object
|
return image_array
|
||||||
image_verify = image.copy()
|
|
||||||
|
|
||||||
# Verify image is valid
|
|
||||||
try:
|
|
||||||
image_verify.verify()
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Image verification failed: {str(e)}")
|
|
||||||
finally:
|
|
||||||
image_verify.close()
|
|
||||||
|
|
||||||
return image
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Error loading image: {str(e)}")
|
raise RuntimeError(f"Error loading image: {str(e)}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Ensure stream is closed
|
|
||||||
if 'image_stream' in locals():
|
if 'image_stream' in locals():
|
||||||
image_stream.close()
|
image_stream.close()
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
row = self.df.iloc[idx]
|
row = self.df.iloc[idx]
|
||||||
|
|
||||||
# Load images using the new method
|
|
||||||
low_res_image = self.load_image(row['image_512'])
|
low_res_image = self.load_image(row['image_512'])
|
||||||
high_res_image = self.load_image(row['image_1024'])
|
high_res_image = self.load_image(row['image_1024'])
|
||||||
|
|
||||||
# Apply augmentation and normalization
|
|
||||||
augmented_low = self.augmentation(image=low_res_image)
|
augmented_low = self.augmentation(image=low_res_image)
|
||||||
low_res = augmented_low['image']
|
|
||||||
|
|
||||||
augmented_high = self.augmentation(image=high_res_image)
|
augmented_high = self.augmentation(image=high_res_image)
|
||||||
high_res = augmented_high['image']
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'low_res': low_res,
|
'low_res': augmented_low['image'],
|
||||||
'high_res': high_res
|
'high_res': augmented_high['image']
|
||||||
}
|
}
|
||||||
from torch.utils.data.dataset import ConcatDataset
|
def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
|
||||||
|
|
||||||
def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10):
|
|
||||||
# Load all datasets and concatenate them
|
|
||||||
loaded_datasets = [aiuNNDataset(d) for d in datasets]
|
loaded_datasets = [aiuNNDataset(d) for d in datasets]
|
||||||
combined_dataset = ConcatDataset(loaded_datasets)
|
combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
|
||||||
|
|
||||||
# Split into training and validation sets
|
train_size = int(0.8 * len(combined_dataset))
|
||||||
train_dataset, val_dataset = combined_dataset.train_val_split()
|
val_size = len(combined_dataset) - train_size
|
||||||
|
train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])
|
||||||
|
|
||||||
train_loader = torch.utils.data.DataLoader(
|
train_loader = torch.utils.data.DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=4
|
num_workers=4,
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=True
|
||||||
)
|
)
|
||||||
|
|
||||||
val_loader = torch.utils.data.DataLoader(
|
val_loader = torch.utils.data.DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=4
|
num_workers=4,
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set device
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
model = model.to(device)
|
||||||
model.to(device)
|
|
||||||
|
|
||||||
# Define loss function and optimizer
|
|
||||||
criterion = nn.MSELoss()
|
criterion = nn.MSELoss()
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
|
optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
|
||||||
|
|
||||||
|
@ -128,73 +101,62 @@ def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10):
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
train_loss = 0.0
|
train_loss = 0.0
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(tqdm(train_loader)):
|
for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"):
|
||||||
# Your training code here
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
low_res = batch['low_res'].to(device)
|
low_res = batch['low_res'].to(device)
|
||||||
high_res = batch['high_res'].to(device)
|
high_res = batch['high_res'].to(device)
|
||||||
|
|
||||||
# Forward pass
|
|
||||||
outputs = model(low_res)
|
|
||||||
|
|
||||||
# Calculate loss
|
|
||||||
loss = criterion(outputs, high_res.permute(0, 3, 1, 2)) # Adjust for channel dimensions
|
|
||||||
|
|
||||||
# Backward pass and optimize
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
outputs = model(low_res)
|
||||||
|
loss = criterion(outputs, high_res)
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
train_loss += loss.item()
|
train_loss += loss.item()
|
||||||
|
|
||||||
avg_train_loss = train_loss / len(train_loader)
|
avg_train_loss = train_loss / len(train_loader)
|
||||||
|
|
||||||
print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
|
print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
|
||||||
|
|
||||||
# Validation
|
|
||||||
model.eval()
|
model.eval()
|
||||||
val_loss = 0.0
|
val_loss = 0.0
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in tqdm(val_loader, desc="Validation"):
|
for batch in tqdm(val_loader, desc="Validation"):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
low_res = batch['low_res'].to(device)
|
low_res = batch['low_res'].to(device)
|
||||||
high_res = batch['high_res'].to(device)
|
high_res = batch['high_res'].to(device)
|
||||||
|
|
||||||
outputs = model(low_res)
|
outputs = model(low_res)
|
||||||
loss = criterion(outputs, high_res.permute(0, 3, 1, 2))
|
loss = criterion(outputs, high_res)
|
||||||
|
|
||||||
val_loss += loss.item()
|
val_loss += loss.item()
|
||||||
|
|
||||||
avg_val_loss = val_loss / len(val_loader)
|
avg_val_loss = val_loss / len(val_loader)
|
||||||
|
|
||||||
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
|
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
|
||||||
|
|
||||||
# Save best model
|
|
||||||
if avg_val_loss < best_val_loss:
|
if avg_val_loss < best_val_loss:
|
||||||
best_val_loss = avg_val_loss
|
best_val_loss = avg_val_loss
|
||||||
model.save("best_model")
|
torch.save(model.state_dict(), "best_model.pth")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Paths to your data
|
BATCH_SIZE = 1
|
||||||
train_parquet_path = "/root/training_data/vision-dataset/image_upscaler.parquet"
|
|
||||||
val_parquet_path = "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
|
|
||||||
|
|
||||||
# Load pretrained model
|
|
||||||
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
|
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512")
|
||||||
|
|
||||||
# Add final upsampling layer if needed (depending on your specific architecture)
|
|
||||||
if hasattr(model, 'chunked_'):
|
if hasattr(model, 'chunked_'):
|
||||||
model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
|
model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
|
||||||
|
|
||||||
# Fine-tune
|
|
||||||
finetune_model(
|
finetune_model(
|
||||||
model,
|
model=model,
|
||||||
train_parquet_path,
|
datasets=[
|
||||||
val_parquet_path
|
"/root/training_data/vision-dataset/image_upscaler.parquet",
|
||||||
|
"/root/training_data/vision-dataset/image_vec_upscaler.parquet"
|
||||||
|
],
|
||||||
|
batch_size=BATCH_SIZE
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue