Compare commits
No commits in common. "96e14b96741d33e75d69a1aa0d4dfaa47534cbb3" and "ef19e24f11623306ee803f06bb3d9dc94b733ae3" have entirely different histories.
96e14b9674
...
ef19e24f11
README.md

@@ -26,22 +26,15 @@ pip install git+https://gitea.fabelous.app/Machine-Learning/aiuNN.git

Here's a basic example of how to use `aiuNN` for image upscaling:

```python src/main.py
from aiia import AIIABase, AIIAConfig
from aiia import AIIABase
from aiunn import aiuNN, aiuNNTrainer
import pandas as pd
from torchvision import transforms

# Create a configuration and build a base model.
config = AIIAConfig()
ai_config = aiuNNConfig()

base_model = AIIABase(config)
upscaler = aiuNN(config=ai_config)

# Load your base model and upscaler
pretrained_model_path = "path/to/aiia/model"
base_model = AIIABase.from_pretrained(pretrained_model_path)
upscaler.load_base_model(base_model)
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
upscaler = aiuNN(base_model)

# Create trainer with your dataset class
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)

@@ -112,19 +105,19 @@ class UpscaleDataset(Dataset):

# Open image bytes with Pillow and convert to RGBA first
low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')

# Create a new RGB image with black background
low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0))
high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0))

# Composite the original image over the black background
low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3])
high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3])

# Now we have true 3-channel RGB images with transparent areas converted to black
low_res = low_res_rgb
high_res = high_res_rgb

# If a transform is provided (e.g. conversion to Tensor), apply it
if self.transform:
    low_res = self.transform(low_res)

@@ -134,4 +127,4 @@ class UpscaleDataset(Dataset):

print(f"\nError at index {idx}: {str(e)}")
self.failed_indices.add(idx)
return self[(idx + 1) % len(self)]
```

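The example passes a user-defined `UpscaleDataset` to `aiuNNTrainer`, but only a fragment of its `__getitem__` appears above. The sketch below is a minimal skeleton of such a dataset, assuming parquet-stored image bytes; the file path and the column names `image_512`/`image_1024` are illustrative placeholders, not taken from the repository.

```python
import io
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

class UpscaleDataset(Dataset):
    """Minimal skeleton of the user-defined dataset the README example assumes."""

    def __init__(self, parquet_path, transform=None):
        # Hypothetical storage format: one parquet file with image bytes per row.
        self.df = pd.read_parquet(parquet_path)
        self.transform = transform          # e.g. torchvision.transforms.ToTensor()
        self.failed_indices = set()         # used by the error handling shown in the diff

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        try:
            row = self.df.iloc[idx]
            low_res_bytes, high_res_bytes = row["image_512"], row["image_1024"]
            # Simplified with convert("RGB"); the README diff instead composites
            # the RGBA image over a black background before training.
            low_res = Image.open(io.BytesIO(low_res_bytes)).convert("RGB")
            high_res = Image.open(io.BytesIO(high_res_bytes)).convert("RGB")
            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)
            return low_res, high_res
        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            return self[(idx + 1) % len(self)]
```
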
setup.py

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages

setup(
    name="aiunn",
    version="0.2.1",
    version="0.2.0",
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    install_requires=[

@@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN

from .upsampler.config import aiuNNConfig
from .inference.inference import aiuNNInference

__version__ = "0.2.1"
__version__ = "0.2.0"

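Both files carry the release number (0.2.0 on one side of the comparison, 0.2.1 on the other). A quick, generic way to confirm which version is actually importable after installing either commit:

```python
# Check the installed aiunn release; which number prints depends on the checkout.
import aiunn
print(aiunn.__version__)
```
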
@@ -10,7 +10,6 @@ from torch.utils.checkpoint import checkpoint

import gc
import time
import shutil
import datetime


class EarlyStopping:

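`EarlyStopping` is constructed further down with `patience` and `min_delta`, and the two trainer variants use it differently: one checks the return value of the call, the other an `early_stop` attribute. The repository's own implementation is not shown in this hunk; the following is a minimal sketch that would satisfy both call patterns, written as an assumption rather than the actual class body.

```python
class EarlyStopping:
    """Minimal sketch consistent with how the trainer uses it in this diff."""

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, loss):
        if loss < self.best_loss - self.min_delta:
            # Sufficient improvement: remember it and reset the patience counter.
            self.best_loss = loss
            self.counter = 0
        else:
            # No (sufficient) improvement this epoch.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        # Returning the flag also supports `if early_stopping(loss): break`.
        return self.early_stop
```
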
@@ -51,16 +50,10 @@ class aiuNNTrainer:

self.optimizer = None
self.scaler = GradScaler()
self.best_loss = float('inf')
self.csv_path = None
self.checkpoint_dir = None
self.use_checkpointing = True
self.data_loader = None
self.validation_loader = None
self.last_checkpoint_time = time.time()
self.checkpoint_interval = 2 * 60 * 60 # 2 hours
self.last_22_date = None
self.recent_checkpoints = []
self.current_epoch = 0

self.log_dir = None

def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
"""

@@ -117,19 +110,23 @@ class aiuNNTrainer:

return self.data_loader, self.validation_loader

def _setup_logging(self, output_path):
"""Set up basic logging and checkpoint directory"""
"""Set up directory structure for logging and model checkpoints"""
timestamp = time.strftime("%Y%m%d-%H%M%S")
self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
os.makedirs(self.log_dir, exist_ok=True)

# Create checkpoint directory
self.checkpoint_dir = os.path.join(output_path, "checkpoints")
self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints")
os.makedirs(self.checkpoint_dir, exist_ok=True)

# Set up CSV logging
self.csv_path = os.path.join(output_path, 'training_log.csv')
self.csv_path = os.path.join(self.log_dir, 'training_log.csv')
with open(self.csv_path, mode='w', newline='') as file:
writer = csv.writer(file)
if self.validation_loader:
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved'])
else:
writer.writerow(['Epoch', 'Train Loss'])
writer.writerow(['Epoch', 'Train Loss', 'Improved'])

def _evaluate(self):
"""Evaluate the model on validation data"""

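The CSV written by `_setup_logging` gains an `Improved` column in one of the two variants. As a usage note, the log can be inspected with pandas (already used in the README example); the run-directory name below is an illustrative placeholder, not a path produced by this repository.

```python
import os
import pandas as pd

# Hypothetical run directory created by _setup_logging; the timestamp is illustrative.
log_dir = "output/training_run_20250101-120000"

# Columns follow the header written above: Epoch, Train Loss, [Validation Loss,] Improved.
df = pd.read_csv(os.path.join(log_dir, "training_log.csv"))
print(df.tail())                 # most recent epochs
print(df["Train Loss"].min())    # best training loss logged so far
```
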
@@ -155,100 +152,64 @@ class aiuNNTrainer:

self.model.train()
return val_loss

def _save_checkpoint(self, epoch, batch_count, is_best=False, is_22=False):
"""Save checkpoint with support for regular, best, and 22:00 saves"""
if is_22:
today = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))).date()
checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
else:
checkpoint_name = f"checkpoint_epoch{epoch}_batch{batch_count}.pt"

checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)
def _save_checkpoint(self, epoch, is_best=False):
"""Save model checkpoint"""
checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt")
best_model_path = os.path.join(self.log_dir, "best_model")

checkpoint_data = {
'epoch': epoch,
'batch': batch_count,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'best_loss': self.best_loss,
'scaler_state_dict': self.scaler.state_dict()
}
# Save the model checkpoint
self.model.save(checkpoint_path)

torch.save(checkpoint_data, checkpoint_path)

# Save best model separately
# If this is the best model so far, copy it to best_model
if is_best:
best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
self.model.save_pretrained(best_model_path)

return checkpoint_path

def _handle_checkpoints(self, epoch, batch_count, is_improved):
"""Handle periodic and 22:00 checkpoint saving"""
current_time = time.time()
current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))

# Regular interval checkpoint
if (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
self._save_checkpoint(epoch, batch_count, is_improved)
self.last_checkpoint_time = current_time

# Special 22:00 checkpoint
is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
if is_22_oclock and self.last_22_date != current_dt.date():
self._save_checkpoint(epoch, batch_count, is_improved, is_22=True)
self.last_22_date = current_dt.date()

if os.path.exists(best_model_path):
shutil.rmtree(best_model_path)
self.model.save(best_model_path)
print(f"Saved new best model with loss: {self.best_loss:.6f}")

def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
"""Finetune the upscaler model"""
"""
Finetune the upscaler model

Args:
output_path (str): Directory to save models and logs
epochs (int): Maximum number of training epochs
lr (float): Learning rate
patience (int): Early stopping patience
min_delta (float): Minimum improvement for early stopping
"""
# Check if data is loaded
if self.data_loader is None:
raise ValueError("Data not loaded. Call load_data first.")

# Setup optimizer and directories
# Setup optimizer
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
self.checkpoint_dir = os.path.join(output_path, "checkpoints")
os.makedirs(self.checkpoint_dir, exist_ok=True)

# Setup CSV logging
self.csv_path = os.path.join(output_path, 'training_log.csv')
with open(self.csv_path, mode='w', newline='') as file:
writer = csv.writer(file)
header = ['Epoch', 'Train Loss', 'Validation Loss'] if self.validation_loader else ['Epoch', 'Train Loss']
writer.writerow(header)

# Load existing checkpoint if available
checkpoint_info = self.load_checkpoint()
start_epoch = checkpoint_info[0] if checkpoint_info else 0
start_batch = checkpoint_info[1] if checkpoint_info else 0
# Set up logging
self._setup_logging(output_path)

# Setup early stopping
early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
self.best_loss = float('inf')

# Training loop
self.model.train()
for epoch in range(start_epoch, epochs):
self.current_epoch = epoch

for epoch in range(epochs):
# Training phase
epoch_loss = 0.0
progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")

train_batches = list(enumerate(self.data_loader))
start_idx = start_batch if epoch == start_epoch else 0

progress_bar = tqdm(train_batches[start_idx:],
                    initial=start_idx,
                    total=len(train_batches),
                    desc=f"Epoch {epoch + 1}/{epochs}")

for batch_idx, (low_res, high_res) in progress_bar:
# Training step
for low_res, high_res in progress_bar:
# Move data to GPU with channels_last format where possible
low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
high_res = high_res.to(self.device, non_blocking=True)

self.optimizer.zero_grad()

with autocast(device_type=self.device.type):
if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
low_res.requires_grad_()
if self.use_checkpointing:
# Ensure the input tensor requires gradient so that checkpointing records the computation graph
low_res.requires_grad_()
outputs = checkpoint(self.model, low_res)
else:
outputs = self.model(low_res)

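The training step combines `autocast` mixed precision with optional gradient checkpointing, but the hunk cuts off before the loss computation and the scaler calls. The sketch below is a generic, self-contained version of that pattern for recent PyTorch on a CUDA device; the toy model, tensors, and loss are placeholders and not the repository's code.

```python
import torch
from torch import nn, optim
from torch.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint

# Toy stand-in for the upscaler; shapes and sizes are arbitrary.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 3, 3, padding=1)
).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scaler = GradScaler()
criterion = nn.MSELoss()

low_res = torch.rand(2, 3, 64, 64, device="cuda")
high_res = torch.rand(2, 3, 64, 64, device="cuda")

optimizer.zero_grad()
with autocast(device_type="cuda"):
    # Trade compute for memory: recompute activations during backward.
    # The trainer marks the input as requiring grad before checkpointing
    # (needed for the reentrant variant; harmless with use_reentrant=False).
    low_res.requires_grad_()
    outputs = checkpoint(model, low_res, use_reentrant=False)
    loss = criterion(outputs, high_res)

# Scale the loss to avoid fp16 underflow, then step and update the scale factor.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```
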
@@ -261,109 +222,69 @@ class aiuNNTrainer:

epoch_loss += loss.item()
progress_bar.set_postfix({'loss': loss.item()})

# Handle checkpoints
self._handle_checkpoints(epoch + 1, batch_idx + 1, loss.item() < self.best_loss)

# Optionally delete variables to free memory
del low_res, high_res, outputs, loss

# End of epoch processing
# Calculate average epoch loss
avg_train_loss = epoch_loss / len(self.data_loader)

# Validation phase
# Validation phase (if validation loader exists)
if self.validation_loader:
val_loss = self._evaluate() / len(self.validation_loader)
is_improved = val_loss < self.best_loss
if is_improved:
self.best_loss = val_loss

# Log to CSV
# Log results
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
with open(self.csv_path, mode='a', newline='') as file:
writer = csv.writer(file)
writer.writerow([epoch + 1, avg_train_loss, val_loss])

print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"])
else:
# If no validation, use training loss for improvement tracking
is_improved = avg_train_loss < self.best_loss
if is_improved:
self.best_loss = avg_train_loss

# Log to CSV
# Log results
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
with open(self.csv_path, mode='a', newline='') as file:
writer = csv.writer(file)
writer.writerow([epoch + 1, avg_train_loss])

print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"])

# Save best model if improved
if is_improved:
best_model_path = os.path.join(output_path, "best_model")
self.model.save_pretrained(best_model_path)
# Save checkpoint
self._save_checkpoint(epoch + 1, is_best=is_improved)

# Check early stopping
if early_stopping(val_loss if self.validation_loader else avg_train_loss):
print(f"Early stopping triggered at epoch {epoch + 1}")
break

# Cleanup
# Perform garbage collection and clear GPU cache after each epoch
gc.collect()
torch.cuda.empty_cache()

# Check early stopping
early_stopping(val_loss if self.validation_loader else avg_train_loss)
if early_stopping.early_stop:
print(f"Early stopping triggered at epoch {epoch + 1}")
break

return self.best_loss

def load_checkpoint(self, specific_checkpoint=None):
"""Enhanced checkpoint loading with specific checkpoint support"""
if specific_checkpoint:
checkpoint_path = os.path.join(self.checkpoint_dir, specific_checkpoint)
else:
checkpoint_files = [f for f in os.listdir(self.checkpoint_dir)
                    if f.startswith("checkpoint_") and f.endswith(".pt")]
if not checkpoint_files:
return None

checkpoint_files.sort(key=lambda x: os.path.getmtime(
    os.path.join(self.checkpoint_dir, x)), reverse=True)
checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_files[0])

if not os.path.exists(checkpoint_path):
return None

checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
self.best_loss = checkpoint['best_loss']

print(f"Loaded checkpoint from {checkpoint_path}")
return checkpoint['epoch'], checkpoint['batch']


def save(self, output_path=None):
"""
Save the best model to the specified path

Args:
output_path (str, optional): Path to save the model. If None, tries to use the checkpoint directory from training.

Returns:
str: Path where the model was saved

Raises:
ValueError: If no output path is specified and no checkpoint directory exists
output_path (str, optional): Path to save the model. If None, uses the best model from training.
"""
if output_path is None and self.checkpoint_dir is not None:
# First try to copy the best model if it exists
best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
if output_path is None and self.log_dir is not None:
best_model_path = os.path.join(self.log_dir, "best_model")
if os.path.exists(best_model_path):
output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
shutil.copytree(best_model_path, output_path, dirs_exist_ok=True)
print(f"Copied best model to {output_path}")
return output_path
print(f"Best model already saved at {best_model_path}")
return best_model_path
else:
# If no best model exists, save current model state
output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
output_path = os.path.join(self.log_dir, "final_model")

if output_path is None:
raise ValueError("No output path specified and no checkpoint directory exists from training.")
raise ValueError("No output path specified and no training has been done yet.")

self.model.save_pretrained(output_path)
self.model.save(output_path)
print(f"Model saved to {output_path}")
return output_path

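Putting the README example and the trainer methods in this diff together, a typical run might look like the sketch below. The dataset keyword arguments and paths are illustrative assumptions, and `UpscaleDataset` is the user-defined dataset from the README (defined elsewhere).

```python
from aiia import AIIABase
from aiunn import aiuNN, aiuNNTrainer

# Variant that loads a pretrained AIIA base model and wraps it in the upscaler.
base_model = AIIABase.load("path/to/aiia/model", precision="bf16")
upscaler = aiuNN(base_model)

# UpscaleDataset is the user-defined Dataset from the README example.
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)

# dataset_params is passed through to UpscaleDataset; the keys here are illustrative.
trainer.load_data(dataset_params={"parquet_path": "data/train.parquet"},
                  batch_size=2, validation_split=0.2)

# Trains with AMP, logs each epoch to CSV, saves checkpoints, and early-stops on plateau.
best_loss = trainer.finetune(output_path="runs", epochs=10, lr=1e-4,
                             patience=3, min_delta=0.001)

# Persist the best (or final) model; falls back to the run directory when no path is given.
trainer.save()
```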