Compare commits


12 Commits

Author SHA1 Message Date
Falko Victor Habel fd924fa024 Merge pull request 'develop' (#15) from develop into main
Reviewed-on: #15
2025-04-20 20:50:50 +00:00
Falko Victor Habel 96e14b9674 Merge pull request 'feat/checkpoints' (#14) from feat/checkpoints into develop
Reviewed-on: #14
2025-04-20 20:44:27 +00:00
Falko Victor Habel f3e59a6586 updated readme to feature tf support
2025-04-20 22:36:43 +02:00
Falko Victor Habel 38530d5d44 increased version number 2025-04-20 22:29:08 +02:00
Falko Victor Habel b0d0b41944 updated trainer to save checkpoints after n hours and at 22:00 with the mission to save energy 2025-04-20 22:28:30 +02:00
Falko Victor Habel ef19e24f11 Merge pull request 'feat/tf_support' (#13) from feat/tf_support into develop
Reviewed-on: #13
2025-04-19 20:59:13 +00:00
Falko Victor Habel 45d6802cd7 increased software version
2025-04-19 22:54:27 +02:00
Falko Victor Habel c4e9432375 dropped full embedding circle
2025-04-19 22:53:58 +02:00
Falko Victor Habel 391a03baed updated tests to match new inference and tf supported model
2025-04-19 22:53:35 +02:00
Falko Victor Habel ac3fabd55f dropped batch processing and dropped fp16 loading
2025-04-19 22:53:13 +02:00
Falko Victor Habel ced7e8a214 moved to transformer support, currently dropped fp16 load support 2025-04-19 22:53:00 +02:00
Falko Victor Habel 16f8de2175 first class but load is still missing, not complete
2025-04-18 23:45:33 +02:00
9 changed files with 202 additions and 293 deletions

View File

@@ -34,4 +34,4 @@ jobs:
VECTORDB_TOKEN: ${{ secrets.VECTORDB_TOKEN }}
run: |
cd VectorLoader
python -m src.run --full
python -m src.run

View File

@@ -26,15 +26,22 @@ pip install git+https://gitea.fabelous.app/Machine-Learning/aiuNN.git
Here's a basic example of how to use `aiuNN` for image upscaling:
```python src/main.py
from aiia import AIIABase
from aiia import AIIABase, AIIAConfig
from aiunn import aiuNN, aiuNNTrainer
import pandas as pd
from torchvision import transforms
# Create a configuration and build a base model.
config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config)
upscaler = aiuNN(config=ai_config)
# Load your base model and upscaler
pretrained_model_path = "path/to/aiia/model"
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
upscaler = aiuNN(base_model)
base_model = AIIABase.from_pretrained(pretrained_model_path)
upscaler.load_base_model(base_model)
# Create trainer with your dataset class
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
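# A minimal sketch of how this example might continue, based on the trainer API in
# this changeset. UpscaleDataset is assumed to be a user-defined torch Dataset that
# yields (low_res, high_res) tensor pairs; all paths are placeholders.
trainer.load_data(
    dataset_params={"csv_file": "path/to/train.csv"},  # assumed to be forwarded to UpscaleDataset
    batch_size=2,
    validation_split=0.2,
)

# Fine-tune; checkpoints and training_log.csv are written under output_path.
trainer.finetune(output_path="training_runs/aiunn", epochs=10, lr=1e-4, patience=3, min_delta=0.001)

# Export the trained model.
trainer.save("training_runs/aiunn/final_model")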

View File

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup(
name="aiunn",
version="0.1.2",
version="0.2.1",
packages=find_packages(where="src"),
package_dir={"": "src"},
install_requires=[

View File

@@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN
from .upsampler.config import aiuNNConfig
from .inference.inference import aiuNNInference
__version__ = "0.1.2"
__version__ = "0.2.1"

View File

@@ -10,6 +10,7 @@ from torch.utils.checkpoint import checkpoint
import gc
import time
import shutil
import datetime
class EarlyStopping:
@@ -50,10 +51,16 @@ class aiuNNTrainer:
self.optimizer = None
self.scaler = GradScaler()
self.best_loss = float('inf')
self.use_checkpointing = True
self.csv_path = None
self.checkpoint_dir = None
self.data_loader = None
self.validation_loader = None
self.log_dir = None
self.last_checkpoint_time = time.time()
self.checkpoint_interval = 2 * 60 * 60 # 2 hours
self.last_22_date = None
self.recent_checkpoints = []
self.current_epoch = 0
def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
"""
@@ -110,23 +117,19 @@
return self.data_loader, self.validation_loader
def _setup_logging(self, output_path):
"""Set up directory structure for logging and model checkpoints"""
timestamp = time.strftime("%Y%m%d-%H%M%S")
self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
os.makedirs(self.log_dir, exist_ok=True)
"""Set up basic logging and checkpoint directory"""
# Create checkpoint directory
self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints")
self.checkpoint_dir = os.path.join(output_path, "checkpoints")
os.makedirs(self.checkpoint_dir, exist_ok=True)
# Set up CSV logging
self.csv_path = os.path.join(self.log_dir, 'training_log.csv')
self.csv_path = os.path.join(output_path, 'training_log.csv')
with open(self.csv_path, mode='w', newline='') as file:
writer = csv.writer(file)
if self.validation_loader:
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved'])
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
else:
writer.writerow(['Epoch', 'Train Loss', 'Improved'])
writer.writerow(['Epoch', 'Train Loss'])
def _evaluate(self):
"""Evaluate the model on validation data"""
@@ -152,63 +155,99 @@
self.model.train()
return val_loss
def _save_checkpoint(self, epoch, is_best=False):
"""Save model checkpoint"""
checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt")
best_model_path = os.path.join(self.log_dir, "best_model")
def _save_checkpoint(self, epoch, batch_count, is_best=False, is_22=False):
"""Save checkpoint with support for regular, best, and 22:00 saves"""
if is_22:
today = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))).date()
checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
else:
checkpoint_name = f"checkpoint_epoch{epoch}_batch{batch_count}.pt"
# Save the model checkpoint
self.model.save(checkpoint_path)
checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)
# If this is the best model so far, copy it to best_model
checkpoint_data = {
'epoch': epoch,
'batch': batch_count,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'best_loss': self.best_loss,
'scaler_state_dict': self.scaler.state_dict()
}
torch.save(checkpoint_data, checkpoint_path)
# Save best model separately
if is_best:
if os.path.exists(best_model_path):
shutil.rmtree(best_model_path)
self.model.save(best_model_path)
print(f"Saved new best model with loss: {self.best_loss:.6f}")
best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
self.model.save_pretrained(best_model_path)
return checkpoint_path
def _handle_checkpoints(self, epoch, batch_count, is_improved):
"""Handle periodic and 22:00 checkpoint saving"""
current_time = time.time()
current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))
# Regular interval checkpoint
if (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
self._save_checkpoint(epoch, batch_count, is_improved)
self.last_checkpoint_time = current_time
# Special 22:00 checkpoint
is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
if is_22_oclock and self.last_22_date != current_dt.date():
self._save_checkpoint(epoch, batch_count, is_improved, is_22=True)
self.last_22_date = current_dt.date()
def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
"""
Finetune the upscaler model
Args:
output_path (str): Directory to save models and logs
epochs (int): Maximum number of training epochs
lr (float): Learning rate
patience (int): Early stopping patience
min_delta (float): Minimum improvement for early stopping
"""
# Check if data is loaded
"""Finetune the upscaler model"""
if self.data_loader is None:
raise ValueError("Data not loaded. Call load_data first.")
# Setup optimizer
# Setup optimizer and directories
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
self.checkpoint_dir = os.path.join(output_path, "checkpoints")
os.makedirs(self.checkpoint_dir, exist_ok=True)
# Set up logging
self._setup_logging(output_path)
# Setup CSV logging
self.csv_path = os.path.join(output_path, 'training_log.csv')
with open(self.csv_path, mode='w', newline='') as file:
writer = csv.writer(file)
header = ['Epoch', 'Train Loss', 'Validation Loss'] if self.validation_loader else ['Epoch', 'Train Loss']
writer.writerow(header)
# Load existing checkpoint if available
checkpoint_info = self.load_checkpoint()
start_epoch = checkpoint_info[0] if checkpoint_info else 0
start_batch = checkpoint_info[1] if checkpoint_info else 0
# Setup early stopping
early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
self.best_loss = float('inf')
# Training loop
self.model.train()
for epoch in range(epochs):
# Training phase
for epoch in range(start_epoch, epochs):
self.current_epoch = epoch
epoch_loss = 0.0
progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")
for low_res, high_res in progress_bar:
# Move data to GPU with channels_last format where possible
train_batches = list(enumerate(self.data_loader))
start_idx = start_batch if epoch == start_epoch else 0
progress_bar = tqdm(train_batches[start_idx:],
initial=start_idx,
total=len(train_batches),
desc=f"Epoch {epoch + 1}/{epochs}")
for batch_idx, (low_res, high_res) in progress_bar:
# Training step
low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
high_res = high_res.to(self.device, non_blocking=True)
self.optimizer.zero_grad()
with autocast(device_type=self.device.type):
if self.use_checkpointing:
# Ensure the input tensor requires gradient so that checkpointing records the computation graph
if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
low_res.requires_grad_()
outputs = checkpoint(self.model, low_res)
else:
@@ -222,69 +261,109 @@ class aiuNNTrainer:
epoch_loss += loss.item()
progress_bar.set_postfix({'loss': loss.item()})
# Optionally delete variables to free memory
# Handle checkpoints
self._handle_checkpoints(epoch + 1, batch_idx + 1, loss.item() < self.best_loss)
del low_res, high_res, outputs, loss
# Calculate average epoch loss
# End of epoch processing
avg_train_loss = epoch_loss / len(self.data_loader)
# Validation phase (if validation loader exists)
# Validation phase
if self.validation_loader:
val_loss = self._evaluate() / len(self.validation_loader)
is_improved = val_loss < self.best_loss
if is_improved:
self.best_loss = val_loss
# Log results
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
# Log to CSV
with open(self.csv_path, mode='a', newline='') as file:
writer = csv.writer(file)
writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"])
writer.writerow([epoch + 1, avg_train_loss, val_loss])
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
else:
# If no validation, use training loss for improvement tracking
is_improved = avg_train_loss < self.best_loss
if is_improved:
self.best_loss = avg_train_loss
# Log results
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
# Log to CSV
with open(self.csv_path, mode='a', newline='') as file:
writer = csv.writer(file)
writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"])
writer.writerow([epoch + 1, avg_train_loss])
# Save checkpoint
self._save_checkpoint(epoch + 1, is_best=is_improved)
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
# Perform garbage collection and clear GPU cache after each epoch
gc.collect()
torch.cuda.empty_cache()
# Save best model if improved
if is_improved:
best_model_path = os.path.join(output_path, "best_model")
self.model.save_pretrained(best_model_path)
# Check early stopping
early_stopping(val_loss if self.validation_loader else avg_train_loss)
if early_stopping.early_stop:
if early_stopping(val_loss if self.validation_loader else avg_train_loss):
print(f"Early stopping triggered at epoch {epoch + 1}")
break
# Cleanup
gc.collect()
torch.cuda.empty_cache()
return self.best_loss
def load_checkpoint(self, specific_checkpoint=None):
"""Enhanced checkpoint loading with specific checkpoint support"""
if specific_checkpoint:
checkpoint_path = os.path.join(self.checkpoint_dir, specific_checkpoint)
else:
checkpoint_files = [f for f in os.listdir(self.checkpoint_dir)
if f.startswith("checkpoint_") and f.endswith(".pt")]
if not checkpoint_files:
return None
checkpoint_files.sort(key=lambda x: os.path.getmtime(
os.path.join(self.checkpoint_dir, x)), reverse=True)
checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_files[0])
if not os.path.exists(checkpoint_path):
return None
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
self.best_loss = checkpoint['best_loss']
print(f"Loaded checkpoint from {checkpoint_path}")
return checkpoint['epoch'], checkpoint['batch']
def save(self, output_path=None):
"""
Save the best model to the specified path
Args:
output_path (str, optional): Path to save the model. If None, uses the best model from training.
output_path (str, optional): Path to save the model. If None, tries to use the checkpoint directory from training.
Returns:
str: Path where the model was saved
Raises:
ValueError: If no output path is specified and no checkpoint directory exists
"""
if output_path is None and self.log_dir is not None:
best_model_path = os.path.join(self.log_dir, "best_model")
if output_path is None and self.checkpoint_dir is not None:
# First try to copy the best model if it exists
best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
if os.path.exists(best_model_path):
print(f"Best model already saved at {best_model_path}")
return best_model_path
output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
shutil.copytree(best_model_path, output_path, dirs_exist_ok=True)
print(f"Copied best model to {output_path}")
return output_path
else:
output_path = os.path.join(self.log_dir, "final_model")
# If no best model exists, save current model state
output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
if output_path is None:
raise ValueError("No output path specified and no training has been done yet.")
raise ValueError("No output path specified and no checkpoint directory exists from training.")
self.model.save(output_path)
self.model.save_pretrained(output_path)
print(f"Model saved to {output_path}")
return output_path
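
The main behavioural change here is time-based checkpointing with automatic resume. A minimal sketch of how this surfaces to a caller, reusing the `upscaler` and `UpscaleDataset` placeholders from the README example; paths and the interval override are illustrative:

```python
from aiunn import aiuNNTrainer

trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
trainer.load_data(dataset_params={"csv_file": "path/to/train.csv"}, batch_size=2)

# Checkpoints land in <output_path>/checkpoints as
#   checkpoint_epoch{E}_batch{B}.pt  -- every checkpoint_interval seconds (2 h by default)
#   checkpoint_22h_YYYYMMDD.pt       -- once per day in the 22:00-22:15 window (UTC+2)
# finetune() reloads the newest checkpoint_*.pt it finds there, so rerunning the same
# call after an interruption resumes from the stored epoch and batch.
trainer.checkpoint_interval = 60 * 60  # optional: checkpoint hourly instead of every 2 h
trainer.finetune(output_path="training_runs/aiunn", epochs=10, lr=1e-4)
```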

View File

@@ -12,13 +12,13 @@ class aiuNNInference:
Inference class for aiuNN upsampling model.
Handles model loading, image upscaling, and output processing.
"""
def __init__(self, model_path: str, precision: Optional[str] = None, device: Optional[str] = None):
def __init__(self, model_path: str, device: Optional[str] = None):
"""
Initialize the inference class by loading the aiuNN model.
Args:
model_path: Path to the saved model directory
precision: Optional precision setting ('fp16', 'bf16', or None for default)
device: Optional device specification ('cuda', 'cpu', or None for auto-detection)
"""
@@ -30,7 +30,7 @@ class aiuNNInference:
self.device = device
# Load the model with specified precision
self.model = aiuNN.load(model_path, precision=precision)
self.model = aiuNN.from_pretrained(model_path)
self.model.to(self.device)
self.model.eval()
@@ -160,54 +160,11 @@ class aiuNNInference:
return binary_data
def process_batch(self,
images: List[Union[str, Image.Image]],
output_dir: Optional[str] = None,
save_format: str = 'PNG',
return_binary: bool = False) -> Union[List[Image.Image], List[bytes], None]:
"""
Process multiple images in batch.
Args:
images: List of input images (paths or PIL Images)
output_dir: Optional directory to save results
save_format: Format to use when saving images
return_binary: Whether to return binary data instead of PIL Images
Returns:
List of processed images or binary data, or None if only saving
"""
results = []
for i, img in enumerate(images):
# Upscale the image
upscaled = self.upscale(img)
# Save if output directory is provided
if output_dir:
# Extract filename if input is a path
if isinstance(img, str):
filename = os.path.basename(img)
base, _ = os.path.splitext(filename)
else:
base = f"upscaled_{i}"
output_path = os.path.join(output_dir, f"{base}.{save_format.lower()}")
self.save(upscaled, output_path, format=save_format)
# Add to results based on return type
if return_binary:
results.append(self.convert_to_binary(upscaled, format=save_format))
else:
results.append(upscaled)
return results if (not output_dir or return_binary or not save_format) else None
# Example usage (can be removed)
if __name__ == "__main__":
# Initialize inference with a model path
inferencer = aiuNNInference("path/to/model", precision="bf16")
inferencer = aiuNNInference("path/to/model")
# Upscale a single image
upscaled_image = inferencer.upscale("input_image.jpg")
@@ -218,9 +175,3 @@ if __name__ == "__main__":
# Convert to binary
binary_data = inferencer.convert_to_binary(upscaled_image)
# Process a batch of images
inferencer.process_batch(
["image1.jpg", "image2.jpg"],
output_dir="output_folder",
save_format="PNG"
)
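
With `process_batch()` removed, several images can still be handled by looping over `upscale()`. A short sketch, assuming `save()` keeps the signature the removed helper used and that the paths are placeholders:

```python
import os
from aiunn import aiuNNInference

inferencer = aiuNNInference("path/to/model", device="cuda")

os.makedirs("output_folder", exist_ok=True)
for path in ["image1.jpg", "image2.jpg"]:
    upscaled = inferencer.upscale(path)  # accepts a file path or a PIL Image
    base, _ = os.path.splitext(os.path.basename(path))
    inferencer.save(upscaled, os.path.join("output_folder", f"{base}.png"), format="PNG")
```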

View File

@@ -2,16 +2,16 @@ import os
import torch
import torch.nn as nn
import warnings
from aiia.model.Model import AIIA, AIIAConfig, AIIABase
from aiia.model.Model import AIIAConfig, AIIABase
from transformers import PreTrainedModel
from .config import aiuNNConfig
import warnings
class aiuNN(AIIA):
def __init__(self, base_model: AIIA, config:aiuNNConfig):
super().__init__(base_model.config)
self.base_model = base_model
class aiuNN(PreTrainedModel):
config_class = aiuNNConfig
def __init__(self, config: aiuNNConfig):
super().__init__(config)
# Pass the unified base configuration using the new parameter.
self.config = config
@@ -26,118 +26,18 @@ class aiuNN(AIIA):
)
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
def load_base_model(self, base_model: PreTrainedModel):
self.base_model = base_model
def forward(self, x):
if self.base_model is None:
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
x = self.base_model(x) # Get base features
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
return x
@classmethod
def load(cls, path, precision: str = None, **kwargs):
"""
Load a aiuNN model from disk with automatic detection of base model type.
Args:
path (str): Directory containing the stored configuration and model parameters.
precision (str, optional): Desired precision for the model's parameters.
**kwargs: Additional keyword arguments to override configuration parameters.
Returns:
An instance of aiuNN with loaded weights.
"""
# Load the configuration
config = aiuNNConfig.load(path)
# Determine the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the state dictionary
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
# Import all model types
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
# Helper function to detect base class type from key patterns
def detect_base_class_type(keys_prefix):
if any(f"{keys_prefix}.shared_layer" in key for key in state_dict.keys()):
return AIIABaseShared
else:
return AIIABase
# Detect base model type
base_model = None
# Check for AIIAmoe with multiple experts
if any("base_model.experts" in key for key in state_dict.keys()):
# Count the number of experts
max_expert_idx = -1
for key in state_dict.keys():
if "base_model.experts." in key:
try:
parts = key.split("base_model.experts.")[1].split(".")
expert_idx = int(parts[0])
max_expert_idx = max(max_expert_idx, expert_idx)
except (ValueError, IndexError):
pass
if max_expert_idx >= 0:
# Determine the type of base_cnn each expert is using
base_class_for_experts = detect_base_class_type("base_model.experts.0.base_cnn")
# Create AIIAmoe with the detected expert count and base class
base_model = AIIAmoe(config, num_experts=max_expert_idx+1, base_class=base_class_for_experts, **kwargs)
# Check for AIIAchunked or AIIArecursive
elif any("base_model.chunked_cnn" in key for key in state_dict.keys()):
if any("recursion_depth" in key for key in state_dict.keys()):
# This is an AIIArecursive model
base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
base_model = AIIArecursive(config, base_class=base_class, **kwargs)
else:
# This is an AIIAchunked model
base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
base_model = AIIAchunked(config, base_class=base_class, **kwargs)
# Check for AIIAExpert
elif any("base_model.base_cnn" in key for key in state_dict.keys()):
# Determine which base class the expert is using
base_class = detect_base_class_type("base_model.base_cnn")
base_model = AIIAExpert(config, base_class=base_class, **kwargs)
# If none of the above, use AIIABase or AIIABaseShared directly
else:
base_class = detect_base_class_type("base_model")
base_model = base_class(config, **kwargs)
# Create the aiuNN model with the detected base model
model = cls(base_model, config=base_model.config)
# Handle precision conversion
dtype = None
if precision is not None:
if precision.lower() == 'fp16':
dtype = torch.float16
elif precision.lower() == 'bf16':
if device == 'cuda' and not torch.cuda.is_bf16_supported():
warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
dtype = torch.float16
else:
dtype = torch.bfloat16
else:
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
if dtype is not None:
for key, param in state_dict.items():
if torch.is_tensor(param):
state_dict[key] = param.to(dtype)
# Load the state dict
model.load_state_dict(state_dict)
return model
if __name__ == "__main__":
from aiia import AIIABase, AIIAConfig
@@ -146,11 +46,11 @@ if __name__ == "__main__":
ai_config = aiuNNConfig()
base_model = AIIABase(config)
# Instantiate Upsampler from the base model (works correctly).
upsampler = aiuNN(base_model, config=ai_config)
upsampler = aiuNN(config=ai_config)
upsampler.load_base_model(base_model)
# Save the model (both configuration and weights).
upsampler.save("aiunn")
upsampler.save_pretrained("aiunn")
# Now load using the overridden load method; this will load the complete model.
upsampler_loaded = aiuNN.load("aiunn", precision="bf16")
upsampler_loaded = aiuNN.from_pretrained("aiunn")
print("Updated configuration:", upsampler_loaded.config.__dict__)

View File

@@ -21,9 +21,8 @@ def real_model(tmp_path):
base_model = AIIABase(config)
# Make sure aiuNN is properly configured with all required attributes
upsampler = aiuNN(base_model, config=ai_config)
# Ensure the upsample attribute is properly set if needed
# upsampler.upsample = ... # Add any necessary initialization
upsampler = aiuNN(config=ai_config)
upsampler.load_base_model(base_model)
# Save the model and config to temporary directory
save_path = str(model_dir / "save")
@@ -40,10 +39,10 @@ def real_model(tmp_path):
json.dump(config_data, f)
# Save model
upsampler.save(save_path)
upsampler.save_pretrained(save_path)
# Load model in inference mode
inference_model = aiuNNInference(model_path=save_path, precision='fp16', device='cpu')
inference_model = aiuNNInference(model_path=save_path, device='cpu')
return inference_model
@@ -88,12 +87,3 @@ def test_convert_to_binary(inference):
result = inference.convert_to_binary(test_image)
assert isinstance(result, bytes)
assert len(result) > 0
def test_process_batch(inference):
# Create test images
test_array = np.zeros((100, 100, 3), dtype=np.uint8)
test_images = [Image.fromarray(test_array) for _ in range(2)]
results = inference.process_batch(test_images)
assert len(results) == 2
assert all(isinstance(img, Image.Image) for img in results)

View File

@@ -10,39 +10,21 @@ def test_save_and_load_model():
config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config)
upsampler = aiuNN(base_model, config=ai_config)
upsampler = aiuNN(config=ai_config)
upsampler.load_base_model(base_model)
# Save the model
save_path = os.path.join(tmpdirname, "model")
upsampler.save(save_path)
upsampler.save_pretrained(save_path)
# Load the model
loaded_upsampler = aiuNN.load(save_path)
loaded_upsampler = aiuNN.from_pretrained(save_path)
# Verify that the loaded model is the same as the original model
assert isinstance(loaded_upsampler, aiuNN)
assert loaded_upsampler.config.__dict__ == upsampler.config.__dict__
assert loaded_upsampler.config.hidden_size == upsampler.config.hidden_size
assert loaded_upsampler.config._activation_function == upsampler.config._activation_function
assert loaded_upsampler.config.architectures == upsampler.config.architectures
def test_save_and_load_model_with_precision():
# Create a temporary directory to save the model
with tempfile.TemporaryDirectory() as tmpdirname:
# Create configurations and build a base model
config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config)
upsampler = aiuNN(base_model, config=ai_config)
# Save the model
save_path = os.path.join(tmpdirname, "model")
upsampler.save(save_path)
# Load the model with precision 'bf16'
loaded_upsampler = aiuNN.load(save_path, precision="bf16")
# Verify that the loaded model is the same as the original model
assert isinstance(loaded_upsampler, aiuNN)
assert loaded_upsampler.config.__dict__ == upsampler.config.__dict__
if __name__ == "__main__":
test_save_and_load_model()
test_save_and_load_model_with_precision()
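
The dedicated precision test was removed together with the custom `precision=` loading path. If reduced precision is still wanted, the stock PyTorch / Hugging Face mechanisms cover it; a hedged sketch (the `torch_dtype` keyword is the standard `from_pretrained` argument and is assumed to pass through aiuNN unchanged, and the path is a placeholder):

```python
import torch
from aiunn import aiuNN

save_path = "path/to/saved/aiunn"  # placeholder

# Cast after loading ...
model_bf16 = aiuNN.from_pretrained(save_path).to(dtype=torch.bfloat16)

# ... or load the weights directly in bf16.
model_bf16 = aiuNN.from_pretrained(save_path, torch_dtype=torch.bfloat16)
```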