Compare commits
No commits in common. "fd924fa024315c6544cdf5b5414d337dff9ff2cb" and "df740140b1c18d3a945fd82063f0345d2a9ecd48" have entirely different histories.
fd924fa024
...
df740140b1
|
@ -34,4 +34,4 @@ jobs:
|
||||||
VECTORDB_TOKEN: ${{ secrets.VECTORDB_TOKEN }}
|
VECTORDB_TOKEN: ${{ secrets.VECTORDB_TOKEN }}
|
||||||
run: |
|
run: |
|
||||||
cd VectorLoader
|
cd VectorLoader
|
||||||
python -m src.run
|
python -m src.run --full
|
||||||
|
|
13
README.md
13
README.md
|
@ -26,22 +26,15 @@ pip install git+https://gitea.fabelous.app/Machine-Learning/aiuNN.git
|
||||||
Here's a basic example of how to use `aiuNN` for image upscaling:
|
Here's a basic example of how to use `aiuNN` for image upscaling:
|
||||||
|
|
||||||
```python src/main.py
|
```python src/main.py
|
||||||
from aiia import AIIABase, AIIAConfig
|
from aiia import AIIABase
|
||||||
from aiunn import aiuNN, aiuNNTrainer
|
from aiunn import aiuNN, aiuNNTrainer
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
# Create a configuration and build a base model.
|
|
||||||
config = AIIAConfig()
|
|
||||||
ai_config = aiuNNConfig()
|
|
||||||
|
|
||||||
base_model = AIIABase(config)
|
|
||||||
upscaler = aiuNN(config=ai_config)
|
|
||||||
|
|
||||||
# Load your base model and upscaler
|
# Load your base model and upscaler
|
||||||
pretrained_model_path = "path/to/aiia/model"
|
pretrained_model_path = "path/to/aiia/model"
|
||||||
base_model = AIIABase.from_pretrained(pretrained_model_path)
|
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
|
||||||
upscaler.load_base_model(base_model)
|
upscaler = aiuNN(base_model)
|
||||||
|
|
||||||
# Create trainer with your dataset class
|
# Create trainer with your dataset class
|
||||||
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
|
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="aiunn",
|
name="aiunn",
|
||||||
version="0.2.1",
|
version="0.1.2",
|
||||||
packages=find_packages(where="src"),
|
packages=find_packages(where="src"),
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
|
|
|
@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN
|
||||||
from .upsampler.config import aiuNNConfig
|
from .upsampler.config import aiuNNConfig
|
||||||
from .inference.inference import aiuNNInference
|
from .inference.inference import aiuNNInference
|
||||||
|
|
||||||
__version__ = "0.2.1"
|
__version__ = "0.1.2"
|
|
@ -10,7 +10,6 @@ from torch.utils.checkpoint import checkpoint
|
||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
import shutil
|
import shutil
|
||||||
import datetime
|
|
||||||
|
|
||||||
|
|
||||||
class EarlyStopping:
|
class EarlyStopping:
|
||||||
|
@ -51,16 +50,10 @@ class aiuNNTrainer:
|
||||||
self.optimizer = None
|
self.optimizer = None
|
||||||
self.scaler = GradScaler()
|
self.scaler = GradScaler()
|
||||||
self.best_loss = float('inf')
|
self.best_loss = float('inf')
|
||||||
self.csv_path = None
|
self.use_checkpointing = True
|
||||||
self.checkpoint_dir = None
|
|
||||||
self.data_loader = None
|
self.data_loader = None
|
||||||
self.validation_loader = None
|
self.validation_loader = None
|
||||||
self.last_checkpoint_time = time.time()
|
self.log_dir = None
|
||||||
self.checkpoint_interval = 2 * 60 * 60 # 2 hours
|
|
||||||
self.last_22_date = None
|
|
||||||
self.recent_checkpoints = []
|
|
||||||
self.current_epoch = 0
|
|
||||||
|
|
||||||
|
|
||||||
def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
|
def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
|
||||||
"""
|
"""
|
||||||
|
@ -117,19 +110,23 @@ class aiuNNTrainer:
|
||||||
return self.data_loader, self.validation_loader
|
return self.data_loader, self.validation_loader
|
||||||
|
|
||||||
def _setup_logging(self, output_path):
|
def _setup_logging(self, output_path):
|
||||||
"""Set up basic logging and checkpoint directory"""
|
"""Set up directory structure for logging and model checkpoints"""
|
||||||
|
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
||||||
|
self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
|
||||||
|
os.makedirs(self.log_dir, exist_ok=True)
|
||||||
|
|
||||||
# Create checkpoint directory
|
# Create checkpoint directory
|
||||||
self.checkpoint_dir = os.path.join(output_path, "checkpoints")
|
self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints")
|
||||||
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
# Set up CSV logging
|
# Set up CSV logging
|
||||||
self.csv_path = os.path.join(output_path, 'training_log.csv')
|
self.csv_path = os.path.join(self.log_dir, 'training_log.csv')
|
||||||
with open(self.csv_path, mode='w', newline='') as file:
|
with open(self.csv_path, mode='w', newline='') as file:
|
||||||
writer = csv.writer(file)
|
writer = csv.writer(file)
|
||||||
if self.validation_loader:
|
if self.validation_loader:
|
||||||
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
|
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved'])
|
||||||
else:
|
else:
|
||||||
writer.writerow(['Epoch', 'Train Loss'])
|
writer.writerow(['Epoch', 'Train Loss', 'Improved'])
|
||||||
|
|
||||||
def _evaluate(self):
|
def _evaluate(self):
|
||||||
"""Evaluate the model on validation data"""
|
"""Evaluate the model on validation data"""
|
||||||
|
@ -155,99 +152,63 @@ class aiuNNTrainer:
|
||||||
self.model.train()
|
self.model.train()
|
||||||
return val_loss
|
return val_loss
|
||||||
|
|
||||||
def _save_checkpoint(self, epoch, batch_count, is_best=False, is_22=False):
|
def _save_checkpoint(self, epoch, is_best=False):
|
||||||
"""Save checkpoint with support for regular, best, and 22:00 saves"""
|
"""Save model checkpoint"""
|
||||||
if is_22:
|
checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt")
|
||||||
today = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))).date()
|
best_model_path = os.path.join(self.log_dir, "best_model")
|
||||||
checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
|
|
||||||
else:
|
|
||||||
checkpoint_name = f"checkpoint_epoch{epoch}_batch{batch_count}.pt"
|
|
||||||
|
|
||||||
checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)
|
# Save the model checkpoint
|
||||||
|
self.model.save(checkpoint_path)
|
||||||
|
|
||||||
checkpoint_data = {
|
# If this is the best model so far, copy it to best_model
|
||||||
'epoch': epoch,
|
|
||||||
'batch': batch_count,
|
|
||||||
'model_state_dict': self.model.state_dict(),
|
|
||||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
|
||||||
'best_loss': self.best_loss,
|
|
||||||
'scaler_state_dict': self.scaler.state_dict()
|
|
||||||
}
|
|
||||||
|
|
||||||
torch.save(checkpoint_data, checkpoint_path)
|
|
||||||
|
|
||||||
# Save best model separately
|
|
||||||
if is_best:
|
if is_best:
|
||||||
best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
|
if os.path.exists(best_model_path):
|
||||||
self.model.save_pretrained(best_model_path)
|
shutil.rmtree(best_model_path)
|
||||||
|
self.model.save(best_model_path)
|
||||||
return checkpoint_path
|
print(f"Saved new best model with loss: {self.best_loss:.6f}")
|
||||||
|
|
||||||
def _handle_checkpoints(self, epoch, batch_count, is_improved):
|
|
||||||
"""Handle periodic and 22:00 checkpoint saving"""
|
|
||||||
current_time = time.time()
|
|
||||||
current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))
|
|
||||||
|
|
||||||
# Regular interval checkpoint
|
|
||||||
if (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
|
|
||||||
self._save_checkpoint(epoch, batch_count, is_improved)
|
|
||||||
self.last_checkpoint_time = current_time
|
|
||||||
|
|
||||||
# Special 22:00 checkpoint
|
|
||||||
is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
|
|
||||||
if is_22_oclock and self.last_22_date != current_dt.date():
|
|
||||||
self._save_checkpoint(epoch, batch_count, is_improved, is_22=True)
|
|
||||||
self.last_22_date = current_dt.date()
|
|
||||||
|
|
||||||
def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
|
def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
|
||||||
"""Finetune the upscaler model"""
|
"""
|
||||||
|
Finetune the upscaler model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_path (str): Directory to save models and logs
|
||||||
|
epochs (int): Maximum number of training epochs
|
||||||
|
lr (float): Learning rate
|
||||||
|
patience (int): Early stopping patience
|
||||||
|
min_delta (float): Minimum improvement for early stopping
|
||||||
|
"""
|
||||||
|
# Check if data is loaded
|
||||||
if self.data_loader is None:
|
if self.data_loader is None:
|
||||||
raise ValueError("Data not loaded. Call load_data first.")
|
raise ValueError("Data not loaded. Call load_data first.")
|
||||||
|
|
||||||
# Setup optimizer and directories
|
# Setup optimizer
|
||||||
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
|
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
|
||||||
self.checkpoint_dir = os.path.join(output_path, "checkpoints")
|
|
||||||
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Setup CSV logging
|
# Set up logging
|
||||||
self.csv_path = os.path.join(output_path, 'training_log.csv')
|
self._setup_logging(output_path)
|
||||||
with open(self.csv_path, mode='w', newline='') as file:
|
|
||||||
writer = csv.writer(file)
|
|
||||||
header = ['Epoch', 'Train Loss', 'Validation Loss'] if self.validation_loader else ['Epoch', 'Train Loss']
|
|
||||||
writer.writerow(header)
|
|
||||||
|
|
||||||
# Load existing checkpoint if available
|
|
||||||
checkpoint_info = self.load_checkpoint()
|
|
||||||
start_epoch = checkpoint_info[0] if checkpoint_info else 0
|
|
||||||
start_batch = checkpoint_info[1] if checkpoint_info else 0
|
|
||||||
|
|
||||||
# Setup early stopping
|
# Setup early stopping
|
||||||
early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
|
early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
|
||||||
self.best_loss = float('inf')
|
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
self.model.train()
|
self.model.train()
|
||||||
for epoch in range(start_epoch, epochs):
|
|
||||||
self.current_epoch = epoch
|
for epoch in range(epochs):
|
||||||
|
# Training phase
|
||||||
epoch_loss = 0.0
|
epoch_loss = 0.0
|
||||||
|
progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")
|
||||||
|
|
||||||
train_batches = list(enumerate(self.data_loader))
|
for low_res, high_res in progress_bar:
|
||||||
start_idx = start_batch if epoch == start_epoch else 0
|
# Move data to GPU with channels_last format where possible
|
||||||
|
|
||||||
progress_bar = tqdm(train_batches[start_idx:],
|
|
||||||
initial=start_idx,
|
|
||||||
total=len(train_batches),
|
|
||||||
desc=f"Epoch {epoch + 1}/{epochs}")
|
|
||||||
|
|
||||||
for batch_idx, (low_res, high_res) in progress_bar:
|
|
||||||
# Training step
|
|
||||||
low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
|
low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
|
||||||
high_res = high_res.to(self.device, non_blocking=True)
|
high_res = high_res.to(self.device, non_blocking=True)
|
||||||
|
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
with autocast(device_type=self.device.type):
|
with autocast(device_type=self.device.type):
|
||||||
if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
|
if self.use_checkpointing:
|
||||||
|
# Ensure the input tensor requires gradient so that checkpointing records the computation graph
|
||||||
low_res.requires_grad_()
|
low_res.requires_grad_()
|
||||||
outputs = checkpoint(self.model, low_res)
|
outputs = checkpoint(self.model, low_res)
|
||||||
else:
|
else:
|
||||||
|
@ -261,109 +222,69 @@ class aiuNNTrainer:
|
||||||
epoch_loss += loss.item()
|
epoch_loss += loss.item()
|
||||||
progress_bar.set_postfix({'loss': loss.item()})
|
progress_bar.set_postfix({'loss': loss.item()})
|
||||||
|
|
||||||
# Handle checkpoints
|
# Optionally delete variables to free memory
|
||||||
self._handle_checkpoints(epoch + 1, batch_idx + 1, loss.item() < self.best_loss)
|
|
||||||
|
|
||||||
del low_res, high_res, outputs, loss
|
del low_res, high_res, outputs, loss
|
||||||
|
|
||||||
# End of epoch processing
|
# Calculate average epoch loss
|
||||||
avg_train_loss = epoch_loss / len(self.data_loader)
|
avg_train_loss = epoch_loss / len(self.data_loader)
|
||||||
|
|
||||||
# Validation phase
|
# Validation phase (if validation loader exists)
|
||||||
if self.validation_loader:
|
if self.validation_loader:
|
||||||
val_loss = self._evaluate() / len(self.validation_loader)
|
val_loss = self._evaluate() / len(self.validation_loader)
|
||||||
is_improved = val_loss < self.best_loss
|
is_improved = val_loss < self.best_loss
|
||||||
if is_improved:
|
if is_improved:
|
||||||
self.best_loss = val_loss
|
self.best_loss = val_loss
|
||||||
|
|
||||||
# Log to CSV
|
# Log results
|
||||||
|
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
|
||||||
with open(self.csv_path, mode='a', newline='') as file:
|
with open(self.csv_path, mode='a', newline='') as file:
|
||||||
writer = csv.writer(file)
|
writer = csv.writer(file)
|
||||||
writer.writerow([epoch + 1, avg_train_loss, val_loss])
|
writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"])
|
||||||
|
|
||||||
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
|
|
||||||
else:
|
else:
|
||||||
|
# If no validation, use training loss for improvement tracking
|
||||||
is_improved = avg_train_loss < self.best_loss
|
is_improved = avg_train_loss < self.best_loss
|
||||||
if is_improved:
|
if is_improved:
|
||||||
self.best_loss = avg_train_loss
|
self.best_loss = avg_train_loss
|
||||||
|
|
||||||
# Log to CSV
|
# Log results
|
||||||
|
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
|
||||||
with open(self.csv_path, mode='a', newline='') as file:
|
with open(self.csv_path, mode='a', newline='') as file:
|
||||||
writer = csv.writer(file)
|
writer = csv.writer(file)
|
||||||
writer.writerow([epoch + 1, avg_train_loss])
|
writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"])
|
||||||
|
|
||||||
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
|
# Save checkpoint
|
||||||
|
self._save_checkpoint(epoch + 1, is_best=is_improved)
|
||||||
|
|
||||||
# Save best model if improved
|
# Perform garbage collection and clear GPU cache after each epoch
|
||||||
if is_improved:
|
|
||||||
best_model_path = os.path.join(output_path, "best_model")
|
|
||||||
self.model.save_pretrained(best_model_path)
|
|
||||||
|
|
||||||
# Check early stopping
|
|
||||||
if early_stopping(val_loss if self.validation_loader else avg_train_loss):
|
|
||||||
print(f"Early stopping triggered at epoch {epoch + 1}")
|
|
||||||
break
|
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Check early stopping
|
||||||
|
early_stopping(val_loss if self.validation_loader else avg_train_loss)
|
||||||
|
if early_stopping.early_stop:
|
||||||
|
print(f"Early stopping triggered at epoch {epoch + 1}")
|
||||||
|
break
|
||||||
|
|
||||||
return self.best_loss
|
return self.best_loss
|
||||||
|
|
||||||
def load_checkpoint(self, specific_checkpoint=None):
|
|
||||||
"""Enhanced checkpoint loading with specific checkpoint support"""
|
|
||||||
if specific_checkpoint:
|
|
||||||
checkpoint_path = os.path.join(self.checkpoint_dir, specific_checkpoint)
|
|
||||||
else:
|
|
||||||
checkpoint_files = [f for f in os.listdir(self.checkpoint_dir)
|
|
||||||
if f.startswith("checkpoint_") and f.endswith(".pt")]
|
|
||||||
if not checkpoint_files:
|
|
||||||
return None
|
|
||||||
|
|
||||||
checkpoint_files.sort(key=lambda x: os.path.getmtime(
|
|
||||||
os.path.join(self.checkpoint_dir, x)), reverse=True)
|
|
||||||
checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_files[0])
|
|
||||||
|
|
||||||
if not os.path.exists(checkpoint_path):
|
|
||||||
return None
|
|
||||||
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
|
||||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
|
||||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
||||||
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
|
||||||
self.best_loss = checkpoint['best_loss']
|
|
||||||
|
|
||||||
print(f"Loaded checkpoint from {checkpoint_path}")
|
|
||||||
return checkpoint['epoch'], checkpoint['batch']
|
|
||||||
|
|
||||||
def save(self, output_path=None):
|
def save(self, output_path=None):
|
||||||
"""
|
"""
|
||||||
Save the best model to the specified path
|
Save the best model to the specified path
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output_path (str, optional): Path to save the model. If None, tries to use the checkpoint directory from training.
|
output_path (str, optional): Path to save the model. If None, uses the best model from training.
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Path where the model was saved
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If no output path is specified and no checkpoint directory exists
|
|
||||||
"""
|
"""
|
||||||
if output_path is None and self.checkpoint_dir is not None:
|
if output_path is None and self.log_dir is not None:
|
||||||
# First try to copy the best model if it exists
|
best_model_path = os.path.join(self.log_dir, "best_model")
|
||||||
best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
|
|
||||||
if os.path.exists(best_model_path):
|
if os.path.exists(best_model_path):
|
||||||
output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
|
print(f"Best model already saved at {best_model_path}")
|
||||||
shutil.copytree(best_model_path, output_path, dirs_exist_ok=True)
|
return best_model_path
|
||||||
print(f"Copied best model to {output_path}")
|
|
||||||
return output_path
|
|
||||||
else:
|
else:
|
||||||
# If no best model exists, save current model state
|
output_path = os.path.join(self.log_dir, "final_model")
|
||||||
output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
|
|
||||||
|
|
||||||
if output_path is None:
|
if output_path is None:
|
||||||
raise ValueError("No output path specified and no checkpoint directory exists from training.")
|
raise ValueError("No output path specified and no training has been done yet.")
|
||||||
|
|
||||||
self.model.save_pretrained(output_path)
|
self.model.save(output_path)
|
||||||
print(f"Model saved to {output_path}")
|
print(f"Model saved to {output_path}")
|
||||||
return output_path
|
return output_path
|
|
@ -12,13 +12,13 @@ class aiuNNInference:
|
||||||
Inference class for aiuNN upsampling model.
|
Inference class for aiuNN upsampling model.
|
||||||
Handles model loading, image upscaling, and output processing.
|
Handles model loading, image upscaling, and output processing.
|
||||||
"""
|
"""
|
||||||
def __init__(self, model_path: str, device: Optional[str] = None):
|
def __init__(self, model_path: str, precision: Optional[str] = None, device: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Initialize the inference class by loading the aiuNN model.
|
Initialize the inference class by loading the aiuNN model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path: Path to the saved model directory
|
model_path: Path to the saved model directory
|
||||||
|
precision: Optional precision setting ('fp16', 'bf16', or None for default)
|
||||||
device: Optional device specification ('cuda', 'cpu', or None for auto-detection)
|
device: Optional device specification ('cuda', 'cpu', or None for auto-detection)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ class aiuNNInference:
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
# Load the model with specified precision
|
# Load the model with specified precision
|
||||||
self.model = aiuNN.from_pretrained(model_path)
|
self.model = aiuNN.load(model_path, precision=precision)
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
|
@ -160,11 +160,54 @@ class aiuNNInference:
|
||||||
|
|
||||||
return binary_data
|
return binary_data
|
||||||
|
|
||||||
|
def process_batch(self,
|
||||||
|
images: List[Union[str, Image.Image]],
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
save_format: str = 'PNG',
|
||||||
|
return_binary: bool = False) -> Union[List[Image.Image], List[bytes], None]:
|
||||||
|
"""
|
||||||
|
Process multiple images in batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: List of input images (paths or PIL Images)
|
||||||
|
output_dir: Optional directory to save results
|
||||||
|
save_format: Format to use when saving images
|
||||||
|
return_binary: Whether to return binary data instead of PIL Images
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of processed images or binary data, or None if only saving
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, img in enumerate(images):
|
||||||
|
# Upscale the image
|
||||||
|
upscaled = self.upscale(img)
|
||||||
|
|
||||||
|
# Save if output directory is provided
|
||||||
|
if output_dir:
|
||||||
|
# Extract filename if input is a path
|
||||||
|
if isinstance(img, str):
|
||||||
|
filename = os.path.basename(img)
|
||||||
|
base, _ = os.path.splitext(filename)
|
||||||
|
else:
|
||||||
|
base = f"upscaled_{i}"
|
||||||
|
|
||||||
|
output_path = os.path.join(output_dir, f"{base}.{save_format.lower()}")
|
||||||
|
self.save(upscaled, output_path, format=save_format)
|
||||||
|
|
||||||
|
# Add to results based on return type
|
||||||
|
if return_binary:
|
||||||
|
results.append(self.convert_to_binary(upscaled, format=save_format))
|
||||||
|
else:
|
||||||
|
results.append(upscaled)
|
||||||
|
|
||||||
|
return results if (not output_dir or return_binary or not save_format) else None
|
||||||
|
|
||||||
|
|
||||||
# Example usage (can be removed)
|
# Example usage (can be removed)
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Initialize inference with a model path
|
# Initialize inference with a model path
|
||||||
inferencer = aiuNNInference("path/to/model")
|
inferencer = aiuNNInference("path/to/model", precision="bf16")
|
||||||
|
|
||||||
# Upscale a single image
|
# Upscale a single image
|
||||||
upscaled_image = inferencer.upscale("input_image.jpg")
|
upscaled_image = inferencer.upscale("input_image.jpg")
|
||||||
|
@ -175,3 +218,9 @@ if __name__ == "__main__":
|
||||||
# Convert to binary
|
# Convert to binary
|
||||||
binary_data = inferencer.convert_to_binary(upscaled_image)
|
binary_data = inferencer.convert_to_binary(upscaled_image)
|
||||||
|
|
||||||
|
# Process a batch of images
|
||||||
|
inferencer.process_batch(
|
||||||
|
["image1.jpg", "image2.jpg"],
|
||||||
|
output_dir="output_folder",
|
||||||
|
save_format="PNG"
|
||||||
|
)
|
|
@ -2,16 +2,16 @@ import os
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import warnings
|
import warnings
|
||||||
from aiia.model.Model import AIIAConfig, AIIABase
|
from aiia.model.Model import AIIA, AIIAConfig, AIIABase
|
||||||
from transformers import PreTrainedModel
|
|
||||||
from .config import aiuNNConfig
|
from .config import aiuNNConfig
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
class aiuNN(PreTrainedModel):
|
class aiuNN(AIIA):
|
||||||
config_class = aiuNNConfig
|
def __init__(self, base_model: AIIA, config:aiuNNConfig):
|
||||||
def __init__(self, config: aiuNNConfig):
|
super().__init__(base_model.config)
|
||||||
super().__init__(config)
|
self.base_model = base_model
|
||||||
|
|
||||||
# Pass the unified base configuration using the new parameter.
|
# Pass the unified base configuration using the new parameter.
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
@ -26,18 +26,118 @@ class aiuNN(PreTrainedModel):
|
||||||
)
|
)
|
||||||
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
||||||
|
|
||||||
def load_base_model(self, base_model: PreTrainedModel):
|
|
||||||
self.base_model = base_model
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.base_model is None:
|
|
||||||
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
|
|
||||||
x = self.base_model(x) # Get base features
|
x = self.base_model(x) # Get base features
|
||||||
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
|
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
|
||||||
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, path, precision: str = None, **kwargs):
|
||||||
|
"""
|
||||||
|
Load a aiuNN model from disk with automatic detection of base model type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): Directory containing the stored configuration and model parameters.
|
||||||
|
precision (str, optional): Desired precision for the model's parameters.
|
||||||
|
**kwargs: Additional keyword arguments to override configuration parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance of aiuNN with loaded weights.
|
||||||
|
"""
|
||||||
|
# Load the configuration
|
||||||
|
config = aiuNNConfig.load(path)
|
||||||
|
|
||||||
|
# Determine the device
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
# Load the state dictionary
|
||||||
|
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
|
||||||
|
|
||||||
|
# Import all model types
|
||||||
|
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
|
||||||
|
|
||||||
|
# Helper function to detect base class type from key patterns
|
||||||
|
def detect_base_class_type(keys_prefix):
|
||||||
|
if any(f"{keys_prefix}.shared_layer" in key for key in state_dict.keys()):
|
||||||
|
return AIIABaseShared
|
||||||
|
else:
|
||||||
|
return AIIABase
|
||||||
|
|
||||||
|
# Detect base model type
|
||||||
|
base_model = None
|
||||||
|
|
||||||
|
# Check for AIIAmoe with multiple experts
|
||||||
|
if any("base_model.experts" in key for key in state_dict.keys()):
|
||||||
|
# Count the number of experts
|
||||||
|
max_expert_idx = -1
|
||||||
|
for key in state_dict.keys():
|
||||||
|
if "base_model.experts." in key:
|
||||||
|
try:
|
||||||
|
parts = key.split("base_model.experts.")[1].split(".")
|
||||||
|
expert_idx = int(parts[0])
|
||||||
|
max_expert_idx = max(max_expert_idx, expert_idx)
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if max_expert_idx >= 0:
|
||||||
|
# Determine the type of base_cnn each expert is using
|
||||||
|
base_class_for_experts = detect_base_class_type("base_model.experts.0.base_cnn")
|
||||||
|
|
||||||
|
# Create AIIAmoe with the detected expert count and base class
|
||||||
|
base_model = AIIAmoe(config, num_experts=max_expert_idx+1, base_class=base_class_for_experts, **kwargs)
|
||||||
|
|
||||||
|
# Check for AIIAchunked or AIIArecursive
|
||||||
|
elif any("base_model.chunked_cnn" in key for key in state_dict.keys()):
|
||||||
|
if any("recursion_depth" in key for key in state_dict.keys()):
|
||||||
|
# This is an AIIArecursive model
|
||||||
|
base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
|
||||||
|
base_model = AIIArecursive(config, base_class=base_class, **kwargs)
|
||||||
|
else:
|
||||||
|
# This is an AIIAchunked model
|
||||||
|
base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
|
||||||
|
base_model = AIIAchunked(config, base_class=base_class, **kwargs)
|
||||||
|
|
||||||
|
# Check for AIIAExpert
|
||||||
|
elif any("base_model.base_cnn" in key for key in state_dict.keys()):
|
||||||
|
# Determine which base class the expert is using
|
||||||
|
base_class = detect_base_class_type("base_model.base_cnn")
|
||||||
|
base_model = AIIAExpert(config, base_class=base_class, **kwargs)
|
||||||
|
|
||||||
|
# If none of the above, use AIIABase or AIIABaseShared directly
|
||||||
|
else:
|
||||||
|
base_class = detect_base_class_type("base_model")
|
||||||
|
base_model = base_class(config, **kwargs)
|
||||||
|
|
||||||
|
# Create the aiuNN model with the detected base model
|
||||||
|
model = cls(base_model, config=base_model.config)
|
||||||
|
|
||||||
|
# Handle precision conversion
|
||||||
|
dtype = None
|
||||||
|
if precision is not None:
|
||||||
|
if precision.lower() == 'fp16':
|
||||||
|
dtype = torch.float16
|
||||||
|
elif precision.lower() == 'bf16':
|
||||||
|
if device == 'cuda' and not torch.cuda.is_bf16_supported():
|
||||||
|
warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
|
||||||
|
dtype = torch.float16
|
||||||
|
else:
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
|
||||||
|
|
||||||
|
if dtype is not None:
|
||||||
|
for key, param in state_dict.items():
|
||||||
|
if torch.is_tensor(param):
|
||||||
|
state_dict[key] = param.to(dtype)
|
||||||
|
|
||||||
|
# Load the state dict
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from aiia import AIIABase, AIIAConfig
|
from aiia import AIIABase, AIIAConfig
|
||||||
|
@ -46,11 +146,11 @@ if __name__ == "__main__":
|
||||||
ai_config = aiuNNConfig()
|
ai_config = aiuNNConfig()
|
||||||
base_model = AIIABase(config)
|
base_model = AIIABase(config)
|
||||||
# Instantiate Upsampler from the base model (works correctly).
|
# Instantiate Upsampler from the base model (works correctly).
|
||||||
upsampler = aiuNN(config=ai_config)
|
upsampler = aiuNN(base_model, config=ai_config)
|
||||||
upsampler.load_base_model(base_model)
|
|
||||||
# Save the model (both configuration and weights).
|
# Save the model (both configuration and weights).
|
||||||
upsampler.save_pretrained("aiunn")
|
upsampler.save("aiunn")
|
||||||
|
|
||||||
# Now load using the overridden load method; this will load the complete model.
|
# Now load using the overridden load method; this will load the complete model.
|
||||||
upsampler_loaded = aiuNN.from_pretrained("aiunn")
|
upsampler_loaded = aiuNN.load("aiunn", precision="bf16")
|
||||||
print("Updated configuration:", upsampler_loaded.config.__dict__)
|
print("Updated configuration:", upsampler_loaded.config.__dict__)
|
||||||
|
|
|
@ -21,8 +21,9 @@ def real_model(tmp_path):
|
||||||
base_model = AIIABase(config)
|
base_model = AIIABase(config)
|
||||||
|
|
||||||
# Make sure aiuNN is properly configured with all required attributes
|
# Make sure aiuNN is properly configured with all required attributes
|
||||||
upsampler = aiuNN(config=ai_config)
|
upsampler = aiuNN(base_model, config=ai_config)
|
||||||
upsampler.load_base_model(base_model)
|
# Ensure the upsample attribute is properly set if needed
|
||||||
|
# upsampler.upsample = ... # Add any necessary initialization
|
||||||
|
|
||||||
# Save the model and config to temporary directory
|
# Save the model and config to temporary directory
|
||||||
save_path = str(model_dir / "save")
|
save_path = str(model_dir / "save")
|
||||||
|
@ -39,10 +40,10 @@ def real_model(tmp_path):
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
# Save model
|
# Save model
|
||||||
upsampler.save_pretrained(save_path)
|
upsampler.save(save_path)
|
||||||
|
|
||||||
# Load model in inference mode
|
# Load model in inference mode
|
||||||
inference_model = aiuNNInference(model_path=save_path, device='cpu')
|
inference_model = aiuNNInference(model_path=save_path, precision='fp16', device='cpu')
|
||||||
return inference_model
|
return inference_model
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,3 +88,12 @@ def test_convert_to_binary(inference):
|
||||||
result = inference.convert_to_binary(test_image)
|
result = inference.convert_to_binary(test_image)
|
||||||
assert isinstance(result, bytes)
|
assert isinstance(result, bytes)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
|
|
||||||
|
def test_process_batch(inference):
|
||||||
|
# Create test images
|
||||||
|
test_array = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||||
|
test_images = [Image.fromarray(test_array) for _ in range(2)]
|
||||||
|
|
||||||
|
results = inference.process_batch(test_images)
|
||||||
|
assert len(results) == 2
|
||||||
|
assert all(isinstance(img, Image.Image) for img in results)
|
|
@ -10,21 +10,39 @@ def test_save_and_load_model():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
ai_config = aiuNNConfig()
|
ai_config = aiuNNConfig()
|
||||||
base_model = AIIABase(config)
|
base_model = AIIABase(config)
|
||||||
upsampler = aiuNN(config=ai_config)
|
upsampler = aiuNN(base_model, config=ai_config)
|
||||||
upsampler.load_base_model(base_model)
|
|
||||||
# Save the model
|
# Save the model
|
||||||
save_path = os.path.join(tmpdirname, "model")
|
save_path = os.path.join(tmpdirname, "model")
|
||||||
upsampler.save_pretrained(save_path)
|
upsampler.save(save_path)
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
loaded_upsampler = aiuNN.from_pretrained(save_path)
|
loaded_upsampler = aiuNN.load(save_path)
|
||||||
|
|
||||||
# Verify that the loaded model is the same as the original model
|
# Verify that the loaded model is the same as the original model
|
||||||
assert isinstance(loaded_upsampler, aiuNN)
|
assert isinstance(loaded_upsampler, aiuNN)
|
||||||
assert loaded_upsampler.config.hidden_size == upsampler.config.hidden_size
|
assert loaded_upsampler.config.__dict__ == upsampler.config.__dict__
|
||||||
assert loaded_upsampler.config._activation_function == upsampler.config._activation_function
|
|
||||||
assert loaded_upsampler.config.architectures == upsampler.config.architectures
|
|
||||||
|
|
||||||
|
def test_save_and_load_model_with_precision():
|
||||||
|
# Create a temporary directory to save the model
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
# Create configurations and build a base model
|
||||||
|
config = AIIAConfig()
|
||||||
|
ai_config = aiuNNConfig()
|
||||||
|
base_model = AIIABase(config)
|
||||||
|
upsampler = aiuNN(base_model, config=ai_config)
|
||||||
|
|
||||||
|
# Save the model
|
||||||
|
save_path = os.path.join(tmpdirname, "model")
|
||||||
|
upsampler.save(save_path)
|
||||||
|
|
||||||
|
# Load the model with precision 'bf16'
|
||||||
|
loaded_upsampler = aiuNN.load(save_path, precision="bf16")
|
||||||
|
|
||||||
|
# Verify that the loaded model is the same as the original model
|
||||||
|
assert isinstance(loaded_upsampler, aiuNN)
|
||||||
|
assert loaded_upsampler.config.__dict__ == upsampler.config.__dict__
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_save_and_load_model()
|
test_save_and_load_model()
|
||||||
|
test_save_and_load_model_with_precision()
|
Loading…
Reference in New Issue