develop #15

Merged
Fabel merged 11 commits from develop into main 2025-04-20 20:50:50 +00:00
9 changed files with 202 additions and 293 deletions

View File

@ -34,4 +34,4 @@ jobs:
VECTORDB_TOKEN: ${{ secrets.VECTORDB_TOKEN }} VECTORDB_TOKEN: ${{ secrets.VECTORDB_TOKEN }}
run: | run: |
cd VectorLoader cd VectorLoader
python -m src.run --full python -m src.run

View File

@ -26,15 +26,22 @@ pip install git+https://gitea.fabelous.app/Machine-Learning/aiuNN.git
Here's a basic example of how to use `aiuNN` for image upscaling: Here's a basic example of how to use `aiuNN` for image upscaling:
```python src/main.py ```python src/main.py
from aiia import AIIABase from aiia import AIIABase, AIIAConfig
from aiunn import aiuNN, aiuNNTrainer from aiunn import aiuNN, aiuNNTrainer
import pandas as pd import pandas as pd
from torchvision import transforms from torchvision import transforms
# Create a configuration and build a base model.
config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config)
upscaler = aiuNN(config=ai_config)
# Load your base model and upscaler # Load your base model and upscaler
pretrained_model_path = "path/to/aiia/model" pretrained_model_path = "path/to/aiia/model"
base_model = AIIABase.load(pretrained_model_path, precision="bf16") base_model = AIIABase.from_pretrained(pretrained_model_path)
upscaler = aiuNN(base_model) upscaler.load_base_model(base_model)
# Create trainer with your dataset class # Create trainer with your dataset class
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset) trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)

View File

@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup( setup(
name="aiunn", name="aiunn",
version="0.1.2", version="0.2.1",
packages=find_packages(where="src"), packages=find_packages(where="src"),
package_dir={"": "src"}, package_dir={"": "src"},
install_requires=[ install_requires=[

View File

@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN
from .upsampler.config import aiuNNConfig from .upsampler.config import aiuNNConfig
from .inference.inference import aiuNNInference from .inference.inference import aiuNNInference
__version__ = "0.1.2" __version__ = "0.2.1"

View File

@ -10,6 +10,7 @@ from torch.utils.checkpoint import checkpoint
import gc import gc
import time import time
import shutil import shutil
import datetime
class EarlyStopping: class EarlyStopping:
@ -50,10 +51,16 @@ class aiuNNTrainer:
self.optimizer = None self.optimizer = None
self.scaler = GradScaler() self.scaler = GradScaler()
self.best_loss = float('inf') self.best_loss = float('inf')
self.use_checkpointing = True self.csv_path = None
self.checkpoint_dir = None
self.data_loader = None self.data_loader = None
self.validation_loader = None self.validation_loader = None
self.log_dir = None self.last_checkpoint_time = time.time()
self.checkpoint_interval = 2 * 60 * 60 # 2 hours
self.last_22_date = None
self.recent_checkpoints = []
self.current_epoch = 0
def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None): def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
""" """
@ -110,23 +117,19 @@ class aiuNNTrainer:
return self.data_loader, self.validation_loader return self.data_loader, self.validation_loader
def _setup_logging(self, output_path): def _setup_logging(self, output_path):
"""Set up directory structure for logging and model checkpoints""" """Set up basic logging and checkpoint directory"""
timestamp = time.strftime("%Y%m%d-%H%M%S")
self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
os.makedirs(self.log_dir, exist_ok=True)
# Create checkpoint directory # Create checkpoint directory
self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints") self.checkpoint_dir = os.path.join(output_path, "checkpoints")
os.makedirs(self.checkpoint_dir, exist_ok=True) os.makedirs(self.checkpoint_dir, exist_ok=True)
# Set up CSV logging # Set up CSV logging
self.csv_path = os.path.join(self.log_dir, 'training_log.csv') self.csv_path = os.path.join(output_path, 'training_log.csv')
with open(self.csv_path, mode='w', newline='') as file: with open(self.csv_path, mode='w', newline='') as file:
writer = csv.writer(file) writer = csv.writer(file)
if self.validation_loader: if self.validation_loader:
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved']) writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
else: else:
writer.writerow(['Epoch', 'Train Loss', 'Improved']) writer.writerow(['Epoch', 'Train Loss'])
def _evaluate(self): def _evaluate(self):
"""Evaluate the model on validation data""" """Evaluate the model on validation data"""
@ -152,63 +155,99 @@ class aiuNNTrainer:
self.model.train() self.model.train()
return val_loss return val_loss
def _save_checkpoint(self, epoch, is_best=False): def _save_checkpoint(self, epoch, batch_count, is_best=False, is_22=False):
"""Save model checkpoint""" """Save checkpoint with support for regular, best, and 22:00 saves"""
checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt") if is_22:
best_model_path = os.path.join(self.log_dir, "best_model") today = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))).date()
checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
else:
checkpoint_name = f"checkpoint_epoch{epoch}_batch{batch_count}.pt"
# Save the model checkpoint checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)
self.model.save(checkpoint_path)
# If this is the best model so far, copy it to best_model checkpoint_data = {
'epoch': epoch,
'batch': batch_count,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'best_loss': self.best_loss,
'scaler_state_dict': self.scaler.state_dict()
}
torch.save(checkpoint_data, checkpoint_path)
# Save best model separately
if is_best: if is_best:
if os.path.exists(best_model_path): best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
shutil.rmtree(best_model_path) self.model.save_pretrained(best_model_path)
self.model.save(best_model_path)
print(f"Saved new best model with loss: {self.best_loss:.6f}") return checkpoint_path
def _handle_checkpoints(self, epoch, batch_count, is_improved):
"""Handle periodic and 22:00 checkpoint saving"""
current_time = time.time()
current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))
# Regular interval checkpoint
if (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
self._save_checkpoint(epoch, batch_count, is_improved)
self.last_checkpoint_time = current_time
# Special 22:00 checkpoint
is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
if is_22_oclock and self.last_22_date != current_dt.date():
self._save_checkpoint(epoch, batch_count, is_improved, is_22=True)
self.last_22_date = current_dt.date()
def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001): def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
""" """Finetune the upscaler model"""
Finetune the upscaler model
Args:
output_path (str): Directory to save models and logs
epochs (int): Maximum number of training epochs
lr (float): Learning rate
patience (int): Early stopping patience
min_delta (float): Minimum improvement for early stopping
"""
# Check if data is loaded
if self.data_loader is None: if self.data_loader is None:
raise ValueError("Data not loaded. Call load_data first.") raise ValueError("Data not loaded. Call load_data first.")
# Setup optimizer # Setup optimizer and directories
self.optimizer = optim.Adam(self.model.parameters(), lr=lr) self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
self.checkpoint_dir = os.path.join(output_path, "checkpoints")
os.makedirs(self.checkpoint_dir, exist_ok=True)
# Set up logging # Setup CSV logging
self._setup_logging(output_path) self.csv_path = os.path.join(output_path, 'training_log.csv')
with open(self.csv_path, mode='w', newline='') as file:
writer = csv.writer(file)
header = ['Epoch', 'Train Loss', 'Validation Loss'] if self.validation_loader else ['Epoch', 'Train Loss']
writer.writerow(header)
# Load existing checkpoint if available
checkpoint_info = self.load_checkpoint()
start_epoch = checkpoint_info[0] if checkpoint_info else 0
start_batch = checkpoint_info[1] if checkpoint_info else 0
# Setup early stopping # Setup early stopping
early_stopping = EarlyStopping(patience=patience, min_delta=min_delta) early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
self.best_loss = float('inf')
# Training loop # Training loop
self.model.train() self.model.train()
for epoch in range(start_epoch, epochs):
for epoch in range(epochs): self.current_epoch = epoch
# Training phase
epoch_loss = 0.0 epoch_loss = 0.0
progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")
for low_res, high_res in progress_bar: train_batches = list(enumerate(self.data_loader))
# Move data to GPU with channels_last format where possible start_idx = start_batch if epoch == start_epoch else 0
progress_bar = tqdm(train_batches[start_idx:],
initial=start_idx,
total=len(train_batches),
desc=f"Epoch {epoch + 1}/{epochs}")
for batch_idx, (low_res, high_res) in progress_bar:
# Training step
low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last) low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
high_res = high_res.to(self.device, non_blocking=True) high_res = high_res.to(self.device, non_blocking=True)
self.optimizer.zero_grad() self.optimizer.zero_grad()
with autocast(device_type=self.device.type): with autocast(device_type=self.device.type):
if self.use_checkpointing: if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
# Ensure the input tensor requires gradient so that checkpointing records the computation graph
low_res.requires_grad_() low_res.requires_grad_()
outputs = checkpoint(self.model, low_res) outputs = checkpoint(self.model, low_res)
else: else:
@ -222,69 +261,109 @@ class aiuNNTrainer:
epoch_loss += loss.item() epoch_loss += loss.item()
progress_bar.set_postfix({'loss': loss.item()}) progress_bar.set_postfix({'loss': loss.item()})
# Optionally delete variables to free memory # Handle checkpoints
self._handle_checkpoints(epoch + 1, batch_idx + 1, loss.item() < self.best_loss)
del low_res, high_res, outputs, loss del low_res, high_res, outputs, loss
# Calculate average epoch loss # End of epoch processing
avg_train_loss = epoch_loss / len(self.data_loader) avg_train_loss = epoch_loss / len(self.data_loader)
# Validation phase (if validation loader exists) # Validation phase
if self.validation_loader: if self.validation_loader:
val_loss = self._evaluate() / len(self.validation_loader) val_loss = self._evaluate() / len(self.validation_loader)
is_improved = val_loss < self.best_loss is_improved = val_loss < self.best_loss
if is_improved: if is_improved:
self.best_loss = val_loss self.best_loss = val_loss
# Log results # Log to CSV
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
with open(self.csv_path, mode='a', newline='') as file: with open(self.csv_path, mode='a', newline='') as file:
writer = csv.writer(file) writer = csv.writer(file)
writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"]) writer.writerow([epoch + 1, avg_train_loss, val_loss])
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
else: else:
# If no validation, use training loss for improvement tracking
is_improved = avg_train_loss < self.best_loss is_improved = avg_train_loss < self.best_loss
if is_improved: if is_improved:
self.best_loss = avg_train_loss self.best_loss = avg_train_loss
# Log results # Log to CSV
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
with open(self.csv_path, mode='a', newline='') as file: with open(self.csv_path, mode='a', newline='') as file:
writer = csv.writer(file) writer = csv.writer(file)
writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"]) writer.writerow([epoch + 1, avg_train_loss])
# Save checkpoint print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
self._save_checkpoint(epoch + 1, is_best=is_improved)
# Perform garbage collection and clear GPU cache after each epoch # Save best model if improved
gc.collect() if is_improved:
torch.cuda.empty_cache() best_model_path = os.path.join(output_path, "best_model")
self.model.save_pretrained(best_model_path)
# Check early stopping # Check early stopping
early_stopping(val_loss if self.validation_loader else avg_train_loss) if early_stopping(val_loss if self.validation_loader else avg_train_loss):
if early_stopping.early_stop:
print(f"Early stopping triggered at epoch {epoch + 1}") print(f"Early stopping triggered at epoch {epoch + 1}")
break break
# Cleanup
gc.collect()
torch.cuda.empty_cache()
return self.best_loss return self.best_loss
def load_checkpoint(self, specific_checkpoint=None):
"""Enhanced checkpoint loading with specific checkpoint support"""
if specific_checkpoint:
checkpoint_path = os.path.join(self.checkpoint_dir, specific_checkpoint)
else:
checkpoint_files = [f for f in os.listdir(self.checkpoint_dir)
if f.startswith("checkpoint_") and f.endswith(".pt")]
if not checkpoint_files:
return None
checkpoint_files.sort(key=lambda x: os.path.getmtime(
os.path.join(self.checkpoint_dir, x)), reverse=True)
checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_files[0])
if not os.path.exists(checkpoint_path):
return None
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
self.best_loss = checkpoint['best_loss']
print(f"Loaded checkpoint from {checkpoint_path}")
return checkpoint['epoch'], checkpoint['batch']
def save(self, output_path=None): def save(self, output_path=None):
""" """
Save the best model to the specified path Save the best model to the specified path
Args: Args:
output_path (str, optional): Path to save the model. If None, uses the best model from training. output_path (str, optional): Path to save the model. If None, tries to use the checkpoint directory from training.
Returns:
str: Path where the model was saved
Raises:
ValueError: If no output path is specified and no checkpoint directory exists
""" """
if output_path is None and self.log_dir is not None: if output_path is None and self.checkpoint_dir is not None:
best_model_path = os.path.join(self.log_dir, "best_model") # First try to copy the best model if it exists
best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
if os.path.exists(best_model_path): if os.path.exists(best_model_path):
print(f"Best model already saved at {best_model_path}") output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
return best_model_path shutil.copytree(best_model_path, output_path, dirs_exist_ok=True)
print(f"Copied best model to {output_path}")
return output_path
else: else:
output_path = os.path.join(self.log_dir, "final_model") # If no best model exists, save current model state
output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
if output_path is None: if output_path is None:
raise ValueError("No output path specified and no training has been done yet.") raise ValueError("No output path specified and no checkpoint directory exists from training.")
self.model.save(output_path) self.model.save_pretrained(output_path)
print(f"Model saved to {output_path}") print(f"Model saved to {output_path}")
return output_path return output_path

View File

@ -12,13 +12,13 @@ class aiuNNInference:
Inference class for aiuNN upsampling model. Inference class for aiuNN upsampling model.
Handles model loading, image upscaling, and output processing. Handles model loading, image upscaling, and output processing.
""" """
def __init__(self, model_path: str, precision: Optional[str] = None, device: Optional[str] = None): def __init__(self, model_path: str, device: Optional[str] = None):
""" """
Initialize the inference class by loading the aiuNN model. Initialize the inference class by loading the aiuNN model.
Args: Args:
model_path: Path to the saved model directory model_path: Path to the saved model directory
precision: Optional precision setting ('fp16', 'bf16', or None for default)
device: Optional device specification ('cuda', 'cpu', or None for auto-detection) device: Optional device specification ('cuda', 'cpu', or None for auto-detection)
""" """
@ -30,7 +30,7 @@ class aiuNNInference:
self.device = device self.device = device
# Load the model with specified precision # Load the model with specified precision
self.model = aiuNN.load(model_path, precision=precision) self.model = aiuNN.from_pretrained(model_path)
self.model.to(self.device) self.model.to(self.device)
self.model.eval() self.model.eval()
@ -160,54 +160,11 @@ class aiuNNInference:
return binary_data return binary_data
def process_batch(self,
images: List[Union[str, Image.Image]],
output_dir: Optional[str] = None,
save_format: str = 'PNG',
return_binary: bool = False) -> Union[List[Image.Image], List[bytes], None]:
"""
Process multiple images in batch.
Args:
images: List of input images (paths or PIL Images)
output_dir: Optional directory to save results
save_format: Format to use when saving images
return_binary: Whether to return binary data instead of PIL Images
Returns:
List of processed images or binary data, or None if only saving
"""
results = []
for i, img in enumerate(images):
# Upscale the image
upscaled = self.upscale(img)
# Save if output directory is provided
if output_dir:
# Extract filename if input is a path
if isinstance(img, str):
filename = os.path.basename(img)
base, _ = os.path.splitext(filename)
else:
base = f"upscaled_{i}"
output_path = os.path.join(output_dir, f"{base}.{save_format.lower()}")
self.save(upscaled, output_path, format=save_format)
# Add to results based on return type
if return_binary:
results.append(self.convert_to_binary(upscaled, format=save_format))
else:
results.append(upscaled)
return results if (not output_dir or return_binary or not save_format) else None
# Example usage (can be removed) # Example usage (can be removed)
if __name__ == "__main__": if __name__ == "__main__":
# Initialize inference with a model path # Initialize inference with a model path
inferencer = aiuNNInference("path/to/model", precision="bf16") inferencer = aiuNNInference("path/to/model")
# Upscale a single image # Upscale a single image
upscaled_image = inferencer.upscale("input_image.jpg") upscaled_image = inferencer.upscale("input_image.jpg")
@ -218,9 +175,3 @@ if __name__ == "__main__":
# Convert to binary # Convert to binary
binary_data = inferencer.convert_to_binary(upscaled_image) binary_data = inferencer.convert_to_binary(upscaled_image)
# Process a batch of images
inferencer.process_batch(
["image1.jpg", "image2.jpg"],
output_dir="output_folder",
save_format="PNG"
)

View File

@ -2,16 +2,16 @@ import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import warnings import warnings
from aiia.model.Model import AIIA, AIIAConfig, AIIABase from aiia.model.Model import AIIAConfig, AIIABase
from transformers import PreTrainedModel
from .config import aiuNNConfig from .config import aiuNNConfig
import warnings import warnings
class aiuNN(AIIA): class aiuNN(PreTrainedModel):
def __init__(self, base_model: AIIA, config:aiuNNConfig): config_class = aiuNNConfig
super().__init__(base_model.config) def __init__(self, config: aiuNNConfig):
self.base_model = base_model super().__init__(config)
# Pass the unified base configuration using the new parameter. # Pass the unified base configuration using the new parameter.
self.config = config self.config = config
@ -26,118 +26,18 @@ class aiuNN(AIIA):
) )
self.pixel_shuffle = nn.PixelShuffle(scale_factor) self.pixel_shuffle = nn.PixelShuffle(scale_factor)
def load_base_model(self, base_model: PreTrainedModel):
self.base_model = base_model
def forward(self, x): def forward(self, x):
if self.base_model is None:
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
x = self.base_model(x) # Get base features x = self.base_model(x) # Get base features
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
return x return x
@classmethod
def load(cls, path, precision: str = None, **kwargs):
"""
Load a aiuNN model from disk with automatic detection of base model type.
Args:
path (str): Directory containing the stored configuration and model parameters.
precision (str, optional): Desired precision for the model's parameters.
**kwargs: Additional keyword arguments to override configuration parameters.
Returns:
An instance of aiuNN with loaded weights.
"""
# Load the configuration
config = aiuNNConfig.load(path)
# Determine the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the state dictionary
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
# Import all model types
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
# Helper function to detect base class type from key patterns
def detect_base_class_type(keys_prefix):
if any(f"{keys_prefix}.shared_layer" in key for key in state_dict.keys()):
return AIIABaseShared
else:
return AIIABase
# Detect base model type
base_model = None
# Check for AIIAmoe with multiple experts
if any("base_model.experts" in key for key in state_dict.keys()):
# Count the number of experts
max_expert_idx = -1
for key in state_dict.keys():
if "base_model.experts." in key:
try:
parts = key.split("base_model.experts.")[1].split(".")
expert_idx = int(parts[0])
max_expert_idx = max(max_expert_idx, expert_idx)
except (ValueError, IndexError):
pass
if max_expert_idx >= 0:
# Determine the type of base_cnn each expert is using
base_class_for_experts = detect_base_class_type("base_model.experts.0.base_cnn")
# Create AIIAmoe with the detected expert count and base class
base_model = AIIAmoe(config, num_experts=max_expert_idx+1, base_class=base_class_for_experts, **kwargs)
# Check for AIIAchunked or AIIArecursive
elif any("base_model.chunked_cnn" in key for key in state_dict.keys()):
if any("recursion_depth" in key for key in state_dict.keys()):
# This is an AIIArecursive model
base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
base_model = AIIArecursive(config, base_class=base_class, **kwargs)
else:
# This is an AIIAchunked model
base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
base_model = AIIAchunked(config, base_class=base_class, **kwargs)
# Check for AIIAExpert
elif any("base_model.base_cnn" in key for key in state_dict.keys()):
# Determine which base class the expert is using
base_class = detect_base_class_type("base_model.base_cnn")
base_model = AIIAExpert(config, base_class=base_class, **kwargs)
# If none of the above, use AIIABase or AIIABaseShared directly
else:
base_class = detect_base_class_type("base_model")
base_model = base_class(config, **kwargs)
# Create the aiuNN model with the detected base model
model = cls(base_model, config=base_model.config)
# Handle precision conversion
dtype = None
if precision is not None:
if precision.lower() == 'fp16':
dtype = torch.float16
elif precision.lower() == 'bf16':
if device == 'cuda' and not torch.cuda.is_bf16_supported():
warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
dtype = torch.float16
else:
dtype = torch.bfloat16
else:
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
if dtype is not None:
for key, param in state_dict.items():
if torch.is_tensor(param):
state_dict[key] = param.to(dtype)
# Load the state dict
model.load_state_dict(state_dict)
return model
if __name__ == "__main__": if __name__ == "__main__":
from aiia import AIIABase, AIIAConfig from aiia import AIIABase, AIIAConfig
@ -146,11 +46,11 @@ if __name__ == "__main__":
ai_config = aiuNNConfig() ai_config = aiuNNConfig()
base_model = AIIABase(config) base_model = AIIABase(config)
# Instantiate Upsampler from the base model (works correctly). # Instantiate Upsampler from the base model (works correctly).
upsampler = aiuNN(base_model, config=ai_config) upsampler = aiuNN(config=ai_config)
upsampler.load_base_model(base_model)
# Save the model (both configuration and weights). # Save the model (both configuration and weights).
upsampler.save("aiunn") upsampler.save_pretrained("aiunn")
# Now load using the overridden load method; this will load the complete model. # Now load using the overridden load method; this will load the complete model.
upsampler_loaded = aiuNN.load("aiunn", precision="bf16") upsampler_loaded = aiuNN.from_pretrained("aiunn")
print("Updated configuration:", upsampler_loaded.config.__dict__) print("Updated configuration:", upsampler_loaded.config.__dict__)

View File

@ -21,9 +21,8 @@ def real_model(tmp_path):
base_model = AIIABase(config) base_model = AIIABase(config)
# Make sure aiuNN is properly configured with all required attributes # Make sure aiuNN is properly configured with all required attributes
upsampler = aiuNN(base_model, config=ai_config) upsampler = aiuNN(config=ai_config)
# Ensure the upsample attribute is properly set if needed upsampler.load_base_model(base_model)
# upsampler.upsample = ... # Add any necessary initialization
# Save the model and config to temporary directory # Save the model and config to temporary directory
save_path = str(model_dir / "save") save_path = str(model_dir / "save")
@ -40,10 +39,10 @@ def real_model(tmp_path):
json.dump(config_data, f) json.dump(config_data, f)
# Save model # Save model
upsampler.save(save_path) upsampler.save_pretrained(save_path)
# Load model in inference mode # Load model in inference mode
inference_model = aiuNNInference(model_path=save_path, precision='fp16', device='cpu') inference_model = aiuNNInference(model_path=save_path, device='cpu')
return inference_model return inference_model
@ -88,12 +87,3 @@ def test_convert_to_binary(inference):
result = inference.convert_to_binary(test_image) result = inference.convert_to_binary(test_image)
assert isinstance(result, bytes) assert isinstance(result, bytes)
assert len(result) > 0 assert len(result) > 0
def test_process_batch(inference):
# Create test images
test_array = np.zeros((100, 100, 3), dtype=np.uint8)
test_images = [Image.fromarray(test_array) for _ in range(2)]
results = inference.process_batch(test_images)
assert len(results) == 2
assert all(isinstance(img, Image.Image) for img in results)

View File

@ -10,39 +10,21 @@ def test_save_and_load_model():
config = AIIAConfig() config = AIIAConfig()
ai_config = aiuNNConfig() ai_config = aiuNNConfig()
base_model = AIIABase(config) base_model = AIIABase(config)
upsampler = aiuNN(base_model, config=ai_config) upsampler = aiuNN(config=ai_config)
upsampler.load_base_model(base_model)
# Save the model # Save the model
save_path = os.path.join(tmpdirname, "model") save_path = os.path.join(tmpdirname, "model")
upsampler.save(save_path) upsampler.save_pretrained(save_path)
# Load the model # Load the model
loaded_upsampler = aiuNN.load(save_path) loaded_upsampler = aiuNN.from_pretrained(save_path)
# Verify that the loaded model is the same as the original model # Verify that the loaded model is the same as the original model
assert isinstance(loaded_upsampler, aiuNN) assert isinstance(loaded_upsampler, aiuNN)
assert loaded_upsampler.config.__dict__ == upsampler.config.__dict__ assert loaded_upsampler.config.hidden_size == upsampler.config.hidden_size
assert loaded_upsampler.config._activation_function == upsampler.config._activation_function
assert loaded_upsampler.config.architectures == upsampler.config.architectures
def test_save_and_load_model_with_precision():
# Create a temporary directory to save the model
with tempfile.TemporaryDirectory() as tmpdirname:
# Create configurations and build a base model
config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config)
upsampler = aiuNN(base_model, config=ai_config)
# Save the model
save_path = os.path.join(tmpdirname, "model")
upsampler.save(save_path)
# Load the model with precision 'bf16'
loaded_upsampler = aiuNN.load(save_path, precision="bf16")
# Verify that the loaded model is the same as the original model
assert isinstance(loaded_upsampler, aiuNN)
assert loaded_upsampler.config.__dict__ == upsampler.config.__dict__
if __name__ == "__main__": if __name__ == "__main__":
test_save_and_load_model() test_save_and_load_model()
test_save_and_load_model_with_precision()