Merge pull request 'finetune_class' (#1) from finetune_class into develop
Reviewed-on: #1
This commit is contained in:
commit
930f3a4885
|
@ -0,0 +1,106 @@
|
|||
from aiia import AIIABase
|
||||
from aiunn import aiuNN
|
||||
from aiunn import aiuNNTrainer
|
||||
import pandas as pd
|
||||
import io
|
||||
import base64
|
||||
from PIL import Image, ImageFile
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
|
||||
class UpscaleDataset(Dataset):
|
||||
def __init__(self, parquet_files: list, transform=None, samples_per_file=2500):
|
||||
combined_df = pd.DataFrame()
|
||||
for parquet_file in parquet_files:
|
||||
# Load a subset from each parquet file
|
||||
df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(samples_per_file)
|
||||
combined_df = pd.concat([combined_df, df], ignore_index=True)
|
||||
|
||||
# Validate rows (ensuring each value is bytes or str)
|
||||
self.df = combined_df.apply(self._validate_row, axis=1)
|
||||
self.transform = transform
|
||||
self.failed_indices = set()
|
||||
|
||||
def _validate_row(self, row):
|
||||
for col in ['image_512', 'image_1024']:
|
||||
if not isinstance(row[col], (bytes, str)):
|
||||
raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
|
||||
return row
|
||||
|
||||
def _decode_image(self, data):
|
||||
try:
|
||||
if isinstance(data, str):
|
||||
return base64.b64decode(data)
|
||||
elif isinstance(data, bytes):
|
||||
return data
|
||||
raise ValueError(f"Unsupported data type: {type(data)}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Decoding failed: {str(e)}")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# If previous call failed for this index, use a different index
|
||||
if idx in self.failed_indices:
|
||||
return self[(idx + 1) % len(self)]
|
||||
try:
|
||||
row = self.df.iloc[idx]
|
||||
low_res_bytes = self._decode_image(row['image_512'])
|
||||
high_res_bytes = self._decode_image(row['image_1024'])
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
# Open image bytes with Pillow and convert to RGBA first
|
||||
low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
|
||||
high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')
|
||||
|
||||
# Create a new RGB image with black background
|
||||
low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0))
|
||||
high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0))
|
||||
|
||||
# Composite the original image over the black background
|
||||
low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3])
|
||||
high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3])
|
||||
|
||||
# Now we have true 3-channel RGB images with transparent areas converted to black
|
||||
low_res = low_res_rgb
|
||||
high_res = high_res_rgb
|
||||
|
||||
# Resize the images to reduce VRAM usage
|
||||
low_res = low_res.resize((384, 384), Image.LANCZOS)
|
||||
high_res = high_res.resize((768, 768), Image.LANCZOS)
|
||||
|
||||
# If a transform is provided (e.g. conversion to Tensor), apply it
|
||||
if self.transform:
|
||||
low_res = self.transform(low_res)
|
||||
high_res = self.transform(high_res)
|
||||
return low_res, high_res
|
||||
except Exception as e:
|
||||
print(f"\nError at index {idx}: {str(e)}")
|
||||
self.failed_indices.add(idx)
|
||||
return self[(idx + 1) % len(self)]
|
||||
|
||||
|
||||
if __name__ =="__main__":
|
||||
# Load your base model and upscaler
|
||||
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
|
||||
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
|
||||
upscaler = aiuNN(base_model)
|
||||
|
||||
# Create trainer with your dataset class
|
||||
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
|
||||
|
||||
# Load data using parameters for your dataset
|
||||
dataset_params = {
|
||||
'parquet_files': [
|
||||
"/root/training_data/vision-dataset/image_upscaler.parquet",
|
||||
"/root/training_data/vision-dataset/image_vec_upscaler.parquet"
|
||||
],
|
||||
'transform': transforms.Compose([transforms.ToTensor()]),
|
||||
'samples_per_file': 5000
|
||||
}
|
||||
trainer.load_data(dataset_params=dataset_params, batch_size=1)
|
||||
|
||||
# Fine-tune the model
|
||||
trainer.finetune(output_path="trained_models")
|
|
@ -0,0 +1,17 @@
|
|||
[build-system]
|
||||
requires = ["setuptools>=45", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "aiunn"
|
||||
version = "0.1.1"
|
||||
description = "Finetuner for image upscaling using AIIA"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = {file = "LICENSE"}
|
||||
authors = [
|
||||
{name = "Falko Habel", email = "falko.habel@gmx.de"},
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://gitea.fabelous.app/Machine-Learning/aiuNN"
|
|
@ -0,0 +1,5 @@
|
|||
torch
|
||||
aiia
|
||||
pillow
|
||||
torchvision
|
||||
sklearn
|
|
@ -0,0 +1,14 @@
|
|||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name="aiunn",
|
||||
version="0.1.1",
|
||||
packages=find_packages(where="src"),
|
||||
package_dir={"": "src"},
|
||||
install_requires=[
|
||||
line.strip()
|
||||
for line in open("requirements.txt")
|
||||
if line.strip() and not line.startswith("#")
|
||||
],
|
||||
python_requires=">=3.10",
|
||||
)
|
|
@ -0,0 +1,5 @@
|
|||
from .finetune.trainer import aiuNNTrainer
|
||||
from .upsampler.aiunn import aiuNN
|
||||
from .upsampler.config import aiuNNConfig
|
||||
|
||||
__version__ = "0.1.1"
|
|
@ -0,0 +1,3 @@
|
|||
from .trainer import aiuNNTrainer
|
||||
|
||||
__all__ = ["aiuNNTrainer" ]
|
|
@ -0,0 +1,289 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import os
|
||||
import csv
|
||||
from torch.amp import autocast, GradScaler
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import gc
|
||||
import time
|
||||
import shutil
|
||||
|
||||
|
||||
class EarlyStopping:
|
||||
def __init__(self, patience=3, min_delta=0.001):
|
||||
# Number of epochs with no significant improvement before stopping
|
||||
# Minimum change in loss required to count as an improvement
|
||||
self.patience = patience
|
||||
self.min_delta = min_delta
|
||||
self.best_loss = float('inf')
|
||||
self.counter = 0
|
||||
self.early_stop = False
|
||||
|
||||
def __call__(self, epoch_loss):
|
||||
if epoch_loss < self.best_loss - self.min_delta:
|
||||
self.best_loss = epoch_loss
|
||||
self.counter = 0
|
||||
return True # Improved
|
||||
else:
|
||||
self.counter += 1
|
||||
if self.counter >= self.patience:
|
||||
self.early_stop = True
|
||||
return False # Not improved
|
||||
|
||||
class aiuNNTrainer:
|
||||
def __init__(self, upscaler_model, dataset_class=None):
|
||||
"""
|
||||
Initialize the upscaler trainer
|
||||
|
||||
Args:
|
||||
upscaler_model: The model to fine-tune
|
||||
dataset_class: The dataset class to use for loading data (optional)
|
||||
"""
|
||||
self.model = upscaler_model
|
||||
self.dataset_class = dataset_class
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.model = self.model.to(self.device, memory_format=torch.channels_last)
|
||||
self.criterion = nn.MSELoss()
|
||||
self.optimizer = None
|
||||
self.scaler = GradScaler()
|
||||
self.best_loss = float('inf')
|
||||
self.use_checkpointing = True
|
||||
self.data_loader = None
|
||||
self.validation_loader = None
|
||||
self.log_dir = None
|
||||
|
||||
def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
|
||||
"""
|
||||
Load data using either a custom dataset instance or the dataset class provided at initialization
|
||||
|
||||
Args:
|
||||
dataset_params (dict/list): Parameters to pass to the dataset class constructor
|
||||
batch_size (int): Batch size for training
|
||||
validation_split (float): Proportion of data to use for validation
|
||||
custom_train_dataset: A pre-instantiated dataset to use for training (optional)
|
||||
custom_val_dataset: A pre-instantiated dataset to use for validation (optional)
|
||||
"""
|
||||
# If custom datasets are provided directly, use them
|
||||
if custom_train_dataset is not None:
|
||||
train_dataset = custom_train_dataset
|
||||
val_dataset = custom_val_dataset if custom_val_dataset is not None else None
|
||||
else:
|
||||
# Otherwise instantiate dataset using the class and parameters
|
||||
if self.dataset_class is None:
|
||||
raise ValueError("No dataset class provided. Either provide a dataset class at initialization or custom datasets.")
|
||||
|
||||
# Create dataset instance
|
||||
dataset = self.dataset_class(**dataset_params if isinstance(dataset_params, dict) else {'parquet_files': dataset_params})
|
||||
|
||||
# Split into train and validation sets
|
||||
dataset_size = len(dataset)
|
||||
val_size = int(validation_split * dataset_size)
|
||||
train_size = dataset_size - val_size
|
||||
|
||||
train_dataset, val_dataset = torch.utils.data.random_split(
|
||||
dataset, [train_size, val_size]
|
||||
)
|
||||
|
||||
# Create data loaders
|
||||
self.data_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=True
|
||||
)
|
||||
|
||||
if val_dataset is not None:
|
||||
self.validation_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=True
|
||||
)
|
||||
print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples")
|
||||
else:
|
||||
self.validation_loader = None
|
||||
print(f"Loaded {len(train_dataset)} training samples (no validation set)")
|
||||
|
||||
return self.data_loader, self.validation_loader
|
||||
|
||||
def _setup_logging(self, output_path):
|
||||
"""Set up directory structure for logging and model checkpoints"""
|
||||
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
||||
self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
|
||||
os.makedirs(self.log_dir, exist_ok=True)
|
||||
|
||||
# Create checkpoint directory
|
||||
self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints")
|
||||
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
||||
|
||||
# Set up CSV logging
|
||||
self.csv_path = os.path.join(self.log_dir, 'training_log.csv')
|
||||
with open(self.csv_path, mode='w', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
if self.validation_loader:
|
||||
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved'])
|
||||
else:
|
||||
writer.writerow(['Epoch', 'Train Loss', 'Improved'])
|
||||
|
||||
def _evaluate(self):
|
||||
"""Evaluate the model on validation data"""
|
||||
if self.validation_loader is None:
|
||||
return 0.0
|
||||
|
||||
self.model.eval()
|
||||
val_loss = 0.0
|
||||
|
||||
with torch.no_grad():
|
||||
for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
|
||||
low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
|
||||
high_res = high_res.to(self.device, non_blocking=True)
|
||||
|
||||
with autocast(device_type=self.device.type):
|
||||
outputs = self.model(low_res)
|
||||
loss = self.criterion(outputs, high_res)
|
||||
|
||||
val_loss += loss.item()
|
||||
|
||||
del low_res, high_res, outputs, loss
|
||||
|
||||
self.model.train()
|
||||
return val_loss
|
||||
|
||||
def _save_checkpoint(self, epoch, is_best=False):
|
||||
"""Save model checkpoint"""
|
||||
checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt")
|
||||
best_model_path = os.path.join(self.log_dir, "best_model")
|
||||
|
||||
# Save the model checkpoint
|
||||
self.model.save(checkpoint_path)
|
||||
|
||||
# If this is the best model so far, copy it to best_model
|
||||
if is_best:
|
||||
if os.path.exists(best_model_path):
|
||||
shutil.rmtree(best_model_path)
|
||||
self.model.save(best_model_path)
|
||||
print(f"Saved new best model with loss: {self.best_loss:.6f}")
|
||||
|
||||
def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
|
||||
"""
|
||||
Finetune the upscaler model
|
||||
|
||||
Args:
|
||||
output_path (str): Directory to save models and logs
|
||||
epochs (int): Maximum number of training epochs
|
||||
lr (float): Learning rate
|
||||
patience (int): Early stopping patience
|
||||
min_delta (float): Minimum improvement for early stopping
|
||||
"""
|
||||
# Check if data is loaded
|
||||
if self.data_loader is None:
|
||||
raise ValueError("Data not loaded. Call load_data first.")
|
||||
|
||||
# Setup optimizer
|
||||
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
|
||||
|
||||
# Set up logging
|
||||
self._setup_logging(output_path)
|
||||
|
||||
# Setup early stopping
|
||||
early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
|
||||
|
||||
# Training loop
|
||||
self.model.train()
|
||||
|
||||
for epoch in range(epochs):
|
||||
# Training phase
|
||||
epoch_loss = 0.0
|
||||
progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")
|
||||
|
||||
for low_res, high_res in progress_bar:
|
||||
# Move data to GPU with channels_last format where possible
|
||||
low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
|
||||
high_res = high_res.to(self.device, non_blocking=True)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
with autocast(device_type=self.device.type):
|
||||
if self.use_checkpointing:
|
||||
# Ensure the input tensor requires gradient so that checkpointing records the computation graph
|
||||
low_res.requires_grad_()
|
||||
outputs = checkpoint(self.model, low_res)
|
||||
else:
|
||||
outputs = self.model(low_res)
|
||||
loss = self.criterion(outputs, high_res)
|
||||
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
|
||||
epoch_loss += loss.item()
|
||||
progress_bar.set_postfix({'loss': loss.item()})
|
||||
|
||||
# Optionally delete variables to free memory
|
||||
del low_res, high_res, outputs, loss
|
||||
|
||||
# Calculate average epoch loss
|
||||
avg_train_loss = epoch_loss / len(self.data_loader)
|
||||
|
||||
# Validation phase (if validation loader exists)
|
||||
if self.validation_loader:
|
||||
val_loss = self._evaluate() / len(self.validation_loader)
|
||||
is_improved = val_loss < self.best_loss
|
||||
if is_improved:
|
||||
self.best_loss = val_loss
|
||||
|
||||
# Log results
|
||||
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
|
||||
with open(self.csv_path, mode='a', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"])
|
||||
else:
|
||||
# If no validation, use training loss for improvement tracking
|
||||
is_improved = avg_train_loss < self.best_loss
|
||||
if is_improved:
|
||||
self.best_loss = avg_train_loss
|
||||
|
||||
# Log results
|
||||
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
|
||||
with open(self.csv_path, mode='a', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"])
|
||||
|
||||
# Save checkpoint
|
||||
self._save_checkpoint(epoch + 1, is_best=is_improved)
|
||||
|
||||
# Perform garbage collection and clear GPU cache after each epoch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Check early stopping
|
||||
if early_stopping(val_loss if self.validation_loader else avg_train_loss):
|
||||
print(f"Early stopping triggered at epoch {epoch + 1}")
|
||||
break
|
||||
|
||||
return self.best_loss
|
||||
|
||||
def save(self, output_path=None):
|
||||
"""
|
||||
Save the best model to the specified path
|
||||
|
||||
Args:
|
||||
output_path (str, optional): Path to save the model. If None, uses the best model from training.
|
||||
"""
|
||||
if output_path is None and self.log_dir is not None:
|
||||
best_model_path = os.path.join(self.log_dir, "best_model")
|
||||
if os.path.exists(best_model_path):
|
||||
print(f"Best model already saved at {best_model_path}")
|
||||
return best_model_path
|
||||
else:
|
||||
output_path = os.path.join(self.log_dir, "final_model")
|
||||
|
||||
if output_path is None:
|
||||
raise ValueError("No output path specified and no training has been done yet.")
|
||||
|
||||
self.model.save(output_path)
|
||||
print(f"Model saved to {output_path}")
|
||||
return output_path
|
|
@ -0,0 +1,96 @@
|
|||
import torch
|
||||
from albumentations import Compose, Normalize
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import io
|
||||
from torch import nn
|
||||
from aiia import AIIABase
|
||||
|
||||
|
||||
class Upscaler(nn.Module):
|
||||
"""
|
||||
Transforms the base model's final feature map using a transposed convolution.
|
||||
The base model produces a feature map of size 512x512.
|
||||
This layer upsamples by a factor of 2 (yielding 1024x1024) and maps the hidden features
|
||||
to the output channels using a single ConvTranspose2d layer.
|
||||
"""
|
||||
def __init__(self, base_model: AIIABase):
|
||||
super(Upscaler, self).__init__()
|
||||
self.base_model = base_model
|
||||
# Instead of adding separate upsampling and convolutional layers, we use a ConvTranspose2d layer.
|
||||
self.last_transform = nn.ConvTranspose2d(
|
||||
in_channels=base_model.config.hidden_size,
|
||||
out_channels=base_model.config.num_channels,
|
||||
kernel_size=base_model.config.kernel_size,
|
||||
stride=2,
|
||||
padding=1,
|
||||
output_padding=1
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
features = self.base_model(x)
|
||||
return self.last_transform(features)
|
||||
|
||||
|
||||
class ImageUpscaler:
|
||||
def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
|
||||
self.device = torch.device(device)
|
||||
self.model = self.load_model(model_path)
|
||||
self.model.eval() # Set to evaluation mode
|
||||
|
||||
# Define preprocessing transformations
|
||||
self.preprocess = Compose([
|
||||
Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
ToTensorV2()
|
||||
])
|
||||
|
||||
def load_model(self, model_path: str):
|
||||
"""
|
||||
Load the trained model from the specified path.
|
||||
"""
|
||||
base_model = AIIABase.load(model_path) # Load base model
|
||||
model = Upscaler(base_model) # Wrap with Upscaler
|
||||
return model.to(self.device)
|
||||
|
||||
def preprocess_image(self, image: Image.Image):
|
||||
"""
|
||||
Preprocess input image for inference.
|
||||
"""
|
||||
if not isinstance(image, Image.Image):
|
||||
raise ValueError("Input must be a PIL.Image.Image object")
|
||||
|
||||
# Convert to numpy array and apply preprocessing
|
||||
image_array = np.array(image)
|
||||
augmented = self.preprocess(image=image_array)
|
||||
|
||||
# Add batch dimension and move to device
|
||||
return augmented['image'].unsqueeze(0).to(self.device)
|
||||
|
||||
def postprocess_image(self, output_tensor: torch.Tensor):
|
||||
"""
|
||||
Convert output tensor back to an image.
|
||||
"""
|
||||
output_tensor = output_tensor.squeeze(0).cpu() # Remove batch dimension
|
||||
output_array = (output_tensor * 0.5 + 0.5).clamp(0, 1).numpy() * 255
|
||||
output_array = output_array.transpose(1, 2, 0).astype(np.uint8) # CHW -> HWC
|
||||
return Image.fromarray(output_array)
|
||||
|
||||
def upscale_image(self, input_image_path: str):
|
||||
"""
|
||||
Perform upscaling on an input image.
|
||||
"""
|
||||
input_image = Image.open(input_image_path).convert('RGB') # Ensure RGB format
|
||||
preprocessed_image = self.preprocess_image(input_image)
|
||||
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast(device_type="cuda"):
|
||||
output_tensor = self.model(preprocessed_image)
|
||||
|
||||
return self.postprocess_image(output_tensor)
|
||||
|
||||
|
||||
# Example usage:
|
||||
upscaler = ImageUpscaler(model_path="/root/vision/aiuNN/best_model")
|
||||
upscaled_image = upscaler.upscale_image("/root/vision/aiuNN/input.jpg")
|
||||
upscaled_image.save("upscaled_image.jpg")
|
|
@ -0,0 +1,5 @@
|
|||
from .aiunn import aiuNN
|
||||
from .config import aiuNNConfig
|
||||
|
||||
|
||||
__all__ = ["aiuNN", "aiuNNConfig"]
|
|
@ -0,0 +1,82 @@
|
|||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import warnings
|
||||
from aiia import AIIA, AIIAConfig, AIIABase
|
||||
from .config import aiuNNConfig
|
||||
import warnings
|
||||
|
||||
|
||||
class aiuNN(AIIA):
|
||||
def __init__(self, base_model: AIIABase):
|
||||
super().__init__(base_model.config)
|
||||
self.base_model = base_model
|
||||
|
||||
# Pass the unified base configuration using the new parameter.
|
||||
self.config = aiuNNConfig(base_config=base_model.config)
|
||||
|
||||
self.upsample = nn.Upsample(
|
||||
scale_factor=self.config.upsample_scale,
|
||||
mode=self.config.upsample_mode,
|
||||
align_corners=self.config.upsample_align_corners
|
||||
)
|
||||
# Conversion layer: change from hidden size channels to 3 channels.
|
||||
self.to_rgb = nn.Conv2d(
|
||||
in_channels=self.base_model.config.hidden_size,
|
||||
out_channels=3,
|
||||
kernel_size=1
|
||||
)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.base_model(x)
|
||||
x = self.upsample(x)
|
||||
x = self.to_rgb(x) # Ensures output has 3 channels.
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def load(cls, path, precision: str = None):
|
||||
# Load the configuration from disk.
|
||||
config = AIIAConfig.load(path)
|
||||
# Reconstruct the base model from the loaded configuration.
|
||||
base_model = AIIABase(config)
|
||||
# Instantiate the Upsampler using the proper base model.
|
||||
upsampler = cls(base_model)
|
||||
|
||||
# Load state dict and handle precision conversion if needed.
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
state_dict = torch.load(f"{path}/model.pth", map_location=device)
|
||||
if precision is not None:
|
||||
if precision.lower() == 'fp16':
|
||||
dtype = torch.float16
|
||||
elif precision.lower() == 'bf16':
|
||||
if device == 'cuda' and not torch.cuda.is_bf16_supported():
|
||||
warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
|
||||
|
||||
for key, param in state_dict.items():
|
||||
if torch.is_tensor(param):
|
||||
state_dict[key] = param.to(dtype)
|
||||
upsampler.load_state_dict(state_dict)
|
||||
return upsampler
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from aiia import AIIABase, AIIAConfig
|
||||
# Create a configuration and build a base model.
|
||||
config = AIIAConfig()
|
||||
base_model = AIIABase(config)
|
||||
# Instantiate Upsampler from the base model (works correctly).
|
||||
upsampler = aiuNN(base_model)
|
||||
|
||||
# Save the model (both configuration and weights).
|
||||
upsampler.save("hehe")
|
||||
|
||||
# Now load using the overridden load method; this will load the complete model.
|
||||
upsampler_loaded = aiuNN.load("hehe", precision="bf16")
|
||||
print("Updated configuration:", upsampler_loaded.config.__dict__)
|
|
@ -0,0 +1,50 @@
|
|||
from aiia import AIIAConfig
|
||||
|
||||
|
||||
class aiuNNConfig(AIIAConfig):
|
||||
def __init__(
|
||||
self,
|
||||
base_config=None,
|
||||
upsample_scale: int = 2,
|
||||
upsample_mode: str = 'bilinear',
|
||||
upsample_align_corners: bool = False,
|
||||
layers=None,
|
||||
**kwargs
|
||||
):
|
||||
# Start with a single configuration dictionary.
|
||||
config_data = {}
|
||||
if base_config is not None:
|
||||
# If base_config is an object with a to_dict method, use it.
|
||||
if hasattr(base_config, "to_dict"):
|
||||
config_data.update(base_config.to_dict())
|
||||
elif isinstance(base_config, dict):
|
||||
config_data.update(base_config)
|
||||
|
||||
# Update with any additional keyword arguments (if needed).
|
||||
config_data.update(kwargs)
|
||||
|
||||
# Initialize base AIIAConfig with a single merged configuration.
|
||||
super().__init__(**config_data)
|
||||
|
||||
# Upsampler-specific parameters.
|
||||
self.upsample_scale = upsample_scale
|
||||
self.upsample_mode = upsample_mode
|
||||
self.upsample_align_corners = upsample_align_corners
|
||||
|
||||
# Use layers from the argument or initialize an empty list.
|
||||
self.layers = layers if layers is not None else []
|
||||
|
||||
# Add the upsample layer details only once.
|
||||
self.add_upsample_layer()
|
||||
|
||||
def add_upsample_layer(self):
|
||||
upsample_layer = {
|
||||
'name': 'Upsample',
|
||||
'type': 'nn.Upsample',
|
||||
'scale_factor': self.upsample_scale,
|
||||
'mode': self.upsample_mode,
|
||||
'align_corners': self.upsample_align_corners
|
||||
}
|
||||
# Append the layer only if it isn’t already present.
|
||||
if not any(layer.get('name') == 'Upsample' for layer in self.layers):
|
||||
self.layers.append(upsample_layer)
|
Loading…
Reference in New Issue