diff --git a/.gitea/workflows/embed.yaml b/.gitea/workflows/embed.yaml index 22c5bda..091db58 100644 --- a/.gitea/workflows/embed.yaml +++ b/.gitea/workflows/embed.yaml @@ -34,4 +34,4 @@ jobs: VECTORDB_TOKEN: ${{ secrets.VECTORDB_TOKEN }} run: | cd VectorLoader - python -m src.run --full + python -m src.run diff --git a/README.md b/README.md index aeb0837..802ba14 100644 --- a/README.md +++ b/README.md @@ -26,15 +26,22 @@ pip install git+https://gitea.fabelous.app/Machine-Learning/aiuNN.git Here's a basic example of how to use `aiuNN` for image upscaling: ```python src/main.py -from aiia import AIIABase +from aiia import AIIABase, AIIAConfig from aiunn import aiuNN, aiuNNTrainer import pandas as pd from torchvision import transforms +# Create a configuration and build a base model. +config = AIIAConfig() +ai_config = aiuNNConfig() + +base_model = AIIABase(config) +upscaler = aiuNN(config=ai_config) + # Load your base model and upscaler pretrained_model_path = "path/to/aiia/model" -base_model = AIIABase.load(pretrained_model_path, precision="bf16") -upscaler = aiuNN(base_model) +base_model = AIIABase.from_pretrained(pretrained_model_path) +upscaler.load_base_model(base_model) # Create trainer with your dataset class trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset) @@ -105,19 +112,19 @@ class UpscaleDataset(Dataset): # Open image bytes with Pillow and convert to RGBA first low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA') high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA') - + # Create a new RGB image with black background low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0)) high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0)) - + # Composite the original image over the black background low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3]) high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3]) - + # Now we have true 3-channel RGB images with transparent areas converted to black low_res = low_res_rgb high_res = high_res_rgb - + # If a transform is provided (e.g. conversion to Tensor), apply it if self.transform: low_res = self.transform(low_res) @@ -127,4 +134,4 @@ class UpscaleDataset(Dataset): print(f"\nError at index {idx}: {str(e)}") self.failed_indices.add(idx) return self[(idx + 1) % len(self)] -``` +``` \ No newline at end of file diff --git a/setup.py b/setup.py index b6c17c3..e934f29 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name="aiunn", - version="0.1.2", + version="0.2.1", packages=find_packages(where="src"), package_dir={"": "src"}, install_requires=[ diff --git a/src/aiunn/__init__.py b/src/aiunn/__init__.py index 16f468f..03bd20e 100644 --- a/src/aiunn/__init__.py +++ b/src/aiunn/__init__.py @@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN from .upsampler.config import aiuNNConfig from .inference.inference import aiuNNInference -__version__ = "0.1.2" \ No newline at end of file +__version__ = "0.2.1" \ No newline at end of file diff --git a/src/aiunn/finetune/trainer.py b/src/aiunn/finetune/trainer.py index b94d57d..33b2533 100644 --- a/src/aiunn/finetune/trainer.py +++ b/src/aiunn/finetune/trainer.py @@ -10,6 +10,7 @@ from torch.utils.checkpoint import checkpoint import gc import time import shutil +import datetime class EarlyStopping: @@ -50,10 +51,16 @@ class aiuNNTrainer: self.optimizer = None self.scaler = GradScaler() self.best_loss = float('inf') - self.use_checkpointing = True + self.csv_path = None + self.checkpoint_dir = None self.data_loader = None self.validation_loader = None - self.log_dir = None + self.last_checkpoint_time = time.time() + self.checkpoint_interval = 2 * 60 * 60 # 2 hours + self.last_22_date = None + self.recent_checkpoints = [] + self.current_epoch = 0 + def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None): """ @@ -110,23 +117,19 @@ class aiuNNTrainer: return self.data_loader, self.validation_loader def _setup_logging(self, output_path): - """Set up directory structure for logging and model checkpoints""" - timestamp = time.strftime("%Y%m%d-%H%M%S") - self.log_dir = os.path.join(output_path, f"training_run_{timestamp}") - os.makedirs(self.log_dir, exist_ok=True) - + """Set up basic logging and checkpoint directory""" # Create checkpoint directory - self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints") + self.checkpoint_dir = os.path.join(output_path, "checkpoints") os.makedirs(self.checkpoint_dir, exist_ok=True) # Set up CSV logging - self.csv_path = os.path.join(self.log_dir, 'training_log.csv') + self.csv_path = os.path.join(output_path, 'training_log.csv') with open(self.csv_path, mode='w', newline='') as file: writer = csv.writer(file) if self.validation_loader: - writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved']) + writer.writerow(['Epoch', 'Train Loss', 'Validation Loss']) else: - writer.writerow(['Epoch', 'Train Loss', 'Improved']) + writer.writerow(['Epoch', 'Train Loss']) def _evaluate(self): """Evaluate the model on validation data""" @@ -152,64 +155,100 @@ class aiuNNTrainer: self.model.train() return val_loss - def _save_checkpoint(self, epoch, is_best=False): - """Save model checkpoint""" - checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt") - best_model_path = os.path.join(self.log_dir, "best_model") + def _save_checkpoint(self, epoch, batch_count, is_best=False, is_22=False): + """Save checkpoint with support for regular, best, and 22:00 saves""" + if is_22: + today = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))).date() + checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt" + else: + checkpoint_name = f"checkpoint_epoch{epoch}_batch{batch_count}.pt" + + checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name) - # Save the model checkpoint - self.model.save(checkpoint_path) + checkpoint_data = { + 'epoch': epoch, + 'batch': batch_count, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'best_loss': self.best_loss, + 'scaler_state_dict': self.scaler.state_dict() + } - # If this is the best model so far, copy it to best_model + torch.save(checkpoint_data, checkpoint_path) + + # Save best model separately if is_best: - if os.path.exists(best_model_path): - shutil.rmtree(best_model_path) - self.model.save(best_model_path) - print(f"Saved new best model with loss: {self.best_loss:.6f}") - - def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001): - """ - Finetune the upscaler model + best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model") + self.model.save_pretrained(best_model_path) + + return checkpoint_path + + def _handle_checkpoints(self, epoch, batch_count, is_improved): + """Handle periodic and 22:00 checkpoint saving""" + current_time = time.time() + current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))) - Args: - output_path (str): Directory to save models and logs - epochs (int): Maximum number of training epochs - lr (float): Learning rate - patience (int): Early stopping patience - min_delta (float): Minimum improvement for early stopping - """ - # Check if data is loaded + # Regular interval checkpoint + if (current_time - self.last_checkpoint_time) >= self.checkpoint_interval: + self._save_checkpoint(epoch, batch_count, is_improved) + self.last_checkpoint_time = current_time + + # Special 22:00 checkpoint + is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15 + if is_22_oclock and self.last_22_date != current_dt.date(): + self._save_checkpoint(epoch, batch_count, is_improved, is_22=True) + self.last_22_date = current_dt.date() + + def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001): + """Finetune the upscaler model""" if self.data_loader is None: raise ValueError("Data not loaded. Call load_data first.") - # Setup optimizer + # Setup optimizer and directories self.optimizer = optim.Adam(self.model.parameters(), lr=lr) + self.checkpoint_dir = os.path.join(output_path, "checkpoints") + os.makedirs(self.checkpoint_dir, exist_ok=True) - # Set up logging - self._setup_logging(output_path) + # Setup CSV logging + self.csv_path = os.path.join(output_path, 'training_log.csv') + with open(self.csv_path, mode='w', newline='') as file: + writer = csv.writer(file) + header = ['Epoch', 'Train Loss', 'Validation Loss'] if self.validation_loader else ['Epoch', 'Train Loss'] + writer.writerow(header) + + # Load existing checkpoint if available + checkpoint_info = self.load_checkpoint() + start_epoch = checkpoint_info[0] if checkpoint_info else 0 + start_batch = checkpoint_info[1] if checkpoint_info else 0 # Setup early stopping early_stopping = EarlyStopping(patience=patience, min_delta=min_delta) + self.best_loss = float('inf') # Training loop self.model.train() - - for epoch in range(epochs): - # Training phase + for epoch in range(start_epoch, epochs): + self.current_epoch = epoch epoch_loss = 0.0 - progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}") - for low_res, high_res in progress_bar: - # Move data to GPU with channels_last format where possible + train_batches = list(enumerate(self.data_loader)) + start_idx = start_batch if epoch == start_epoch else 0 + + progress_bar = tqdm(train_batches[start_idx:], + initial=start_idx, + total=len(train_batches), + desc=f"Epoch {epoch + 1}/{epochs}") + + for batch_idx, (low_res, high_res) in progress_bar: + # Training step low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last) high_res = high_res.to(self.device, non_blocking=True) self.optimizer.zero_grad() with autocast(device_type=self.device.type): - if self.use_checkpointing: - # Ensure the input tensor requires gradient so that checkpointing records the computation graph - low_res.requires_grad_() + if hasattr(self, 'use_checkpointing') and self.use_checkpointing: + low_res.requires_grad_() outputs = checkpoint(self.model, low_res) else: outputs = self.model(low_res) @@ -222,69 +261,109 @@ class aiuNNTrainer: epoch_loss += loss.item() progress_bar.set_postfix({'loss': loss.item()}) - # Optionally delete variables to free memory + # Handle checkpoints + self._handle_checkpoints(epoch + 1, batch_idx + 1, loss.item() < self.best_loss) + del low_res, high_res, outputs, loss - # Calculate average epoch loss + # End of epoch processing avg_train_loss = epoch_loss / len(self.data_loader) - # Validation phase (if validation loader exists) + # Validation phase if self.validation_loader: val_loss = self._evaluate() / len(self.validation_loader) is_improved = val_loss < self.best_loss if is_improved: self.best_loss = val_loss - # Log results - print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}") + # Log to CSV with open(self.csv_path, mode='a', newline='') as file: writer = csv.writer(file) - writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"]) + writer.writerow([epoch + 1, avg_train_loss, val_loss]) + + print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}") else: - # If no validation, use training loss for improvement tracking is_improved = avg_train_loss < self.best_loss if is_improved: self.best_loss = avg_train_loss - # Log results - print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}") + # Log to CSV with open(self.csv_path, mode='a', newline='') as file: writer = csv.writer(file) - writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"]) + writer.writerow([epoch + 1, avg_train_loss]) + + print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}") - # Save checkpoint - self._save_checkpoint(epoch + 1, is_best=is_improved) - - # Perform garbage collection and clear GPU cache after each epoch - gc.collect() - torch.cuda.empty_cache() + # Save best model if improved + if is_improved: + best_model_path = os.path.join(output_path, "best_model") + self.model.save_pretrained(best_model_path) # Check early stopping - early_stopping(val_loss if self.validation_loader else avg_train_loss) - if early_stopping.early_stop: + if early_stopping(val_loss if self.validation_loader else avg_train_loss): print(f"Early stopping triggered at epoch {epoch + 1}") break + + # Cleanup + gc.collect() + torch.cuda.empty_cache() return self.best_loss - + + def load_checkpoint(self, specific_checkpoint=None): + """Enhanced checkpoint loading with specific checkpoint support""" + if specific_checkpoint: + checkpoint_path = os.path.join(self.checkpoint_dir, specific_checkpoint) + else: + checkpoint_files = [f for f in os.listdir(self.checkpoint_dir) + if f.startswith("checkpoint_") and f.endswith(".pt")] + if not checkpoint_files: + return None + + checkpoint_files.sort(key=lambda x: os.path.getmtime( + os.path.join(self.checkpoint_dir, x)), reverse=True) + checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_files[0]) + + if not os.path.exists(checkpoint_path): + return None + + checkpoint = torch.load(checkpoint_path, map_location=self.device) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + self.scaler.load_state_dict(checkpoint['scaler_state_dict']) + self.best_loss = checkpoint['best_loss'] + + print(f"Loaded checkpoint from {checkpoint_path}") + return checkpoint['epoch'], checkpoint['batch'] + def save(self, output_path=None): """ Save the best model to the specified path Args: - output_path (str, optional): Path to save the model. If None, uses the best model from training. + output_path (str, optional): Path to save the model. If None, tries to use the checkpoint directory from training. + + Returns: + str: Path where the model was saved + + Raises: + ValueError: If no output path is specified and no checkpoint directory exists """ - if output_path is None and self.log_dir is not None: - best_model_path = os.path.join(self.log_dir, "best_model") + if output_path is None and self.checkpoint_dir is not None: + # First try to copy the best model if it exists + best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model") if os.path.exists(best_model_path): - print(f"Best model already saved at {best_model_path}") - return best_model_path + output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model") + shutil.copytree(best_model_path, output_path, dirs_exist_ok=True) + print(f"Copied best model to {output_path}") + return output_path else: - output_path = os.path.join(self.log_dir, "final_model") + # If no best model exists, save current model state + output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model") if output_path is None: - raise ValueError("No output path specified and no training has been done yet.") + raise ValueError("No output path specified and no checkpoint directory exists from training.") - self.model.save(output_path) + self.model.save_pretrained(output_path) print(f"Model saved to {output_path}") return output_path \ No newline at end of file diff --git a/src/aiunn/inference/inference.py b/src/aiunn/inference/inference.py index d288931..1ff5994 100644 --- a/src/aiunn/inference/inference.py +++ b/src/aiunn/inference/inference.py @@ -12,13 +12,13 @@ class aiuNNInference: Inference class for aiuNN upsampling model. Handles model loading, image upscaling, and output processing. """ - def __init__(self, model_path: str, precision: Optional[str] = None, device: Optional[str] = None): + def __init__(self, model_path: str, device: Optional[str] = None): """ Initialize the inference class by loading the aiuNN model. Args: model_path: Path to the saved model directory - precision: Optional precision setting ('fp16', 'bf16', or None for default) + device: Optional device specification ('cuda', 'cpu', or None for auto-detection) """ @@ -30,7 +30,7 @@ class aiuNNInference: self.device = device # Load the model with specified precision - self.model = aiuNN.load(model_path, precision=precision) + self.model = aiuNN.from_pretrained(model_path) self.model.to(self.device) self.model.eval() @@ -160,54 +160,11 @@ class aiuNNInference: return binary_data - def process_batch(self, - images: List[Union[str, Image.Image]], - output_dir: Optional[str] = None, - save_format: str = 'PNG', - return_binary: bool = False) -> Union[List[Image.Image], List[bytes], None]: - """ - Process multiple images in batch. - - Args: - images: List of input images (paths or PIL Images) - output_dir: Optional directory to save results - save_format: Format to use when saving images - return_binary: Whether to return binary data instead of PIL Images - - Returns: - List of processed images or binary data, or None if only saving - """ - results = [] - - for i, img in enumerate(images): - # Upscale the image - upscaled = self.upscale(img) - - # Save if output directory is provided - if output_dir: - # Extract filename if input is a path - if isinstance(img, str): - filename = os.path.basename(img) - base, _ = os.path.splitext(filename) - else: - base = f"upscaled_{i}" - - output_path = os.path.join(output_dir, f"{base}.{save_format.lower()}") - self.save(upscaled, output_path, format=save_format) - - # Add to results based on return type - if return_binary: - results.append(self.convert_to_binary(upscaled, format=save_format)) - else: - results.append(upscaled) - - return results if (not output_dir or return_binary or not save_format) else None - # Example usage (can be removed) if __name__ == "__main__": # Initialize inference with a model path - inferencer = aiuNNInference("path/to/model", precision="bf16") + inferencer = aiuNNInference("path/to/model") # Upscale a single image upscaled_image = inferencer.upscale("input_image.jpg") @@ -217,10 +174,4 @@ if __name__ == "__main__": # Convert to binary binary_data = inferencer.convert_to_binary(upscaled_image) - - # Process a batch of images - inferencer.process_batch( - ["image1.jpg", "image2.jpg"], - output_dir="output_folder", - save_format="PNG" - ) \ No newline at end of file + \ No newline at end of file diff --git a/src/aiunn/upsampler/aiunn.py b/src/aiunn/upsampler/aiunn.py index ecb3c21..71c77f3 100644 --- a/src/aiunn/upsampler/aiunn.py +++ b/src/aiunn/upsampler/aiunn.py @@ -2,19 +2,19 @@ import os import torch import torch.nn as nn import warnings -from aiia.model.Model import AIIA, AIIAConfig, AIIABase +from aiia.model.Model import AIIAConfig, AIIABase +from transformers import PreTrainedModel from .config import aiuNNConfig import warnings -class aiuNN(AIIA): - def __init__(self, base_model: AIIA, config:aiuNNConfig): - super().__init__(base_model.config) - self.base_model = base_model - +class aiuNN(PreTrainedModel): + config_class = aiuNNConfig + def __init__(self, config: aiuNNConfig): + super().__init__(config) # Pass the unified base configuration using the new parameter. self.config = config - + # Enhanced approach scale_factor = self.config.upsample_scale out_channels = self.base_model.config.num_channels * (scale_factor ** 2) @@ -26,118 +26,18 @@ class aiuNN(AIIA): ) self.pixel_shuffle = nn.PixelShuffle(scale_factor) - + def load_base_model(self, base_model: PreTrainedModel): + self.base_model = base_model + def forward(self, x): + if self.base_model is None: + raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.") x = self.base_model(x) # Get base features x = self.pixel_shuffle_conv(x) # Expand channels for shuffling x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions return x - @classmethod - def load(cls, path, precision: str = None, **kwargs): - """ - Load a aiuNN model from disk with automatic detection of base model type. - - Args: - path (str): Directory containing the stored configuration and model parameters. - precision (str, optional): Desired precision for the model's parameters. - **kwargs: Additional keyword arguments to override configuration parameters. - - Returns: - An instance of aiuNN with loaded weights. - """ - # Load the configuration - config = aiuNNConfig.load(path) - - # Determine the device - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - # Load the state dictionary - state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device) - - # Import all model types - from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive - - # Helper function to detect base class type from key patterns - def detect_base_class_type(keys_prefix): - if any(f"{keys_prefix}.shared_layer" in key for key in state_dict.keys()): - return AIIABaseShared - else: - return AIIABase - - # Detect base model type - base_model = None - - # Check for AIIAmoe with multiple experts - if any("base_model.experts" in key for key in state_dict.keys()): - # Count the number of experts - max_expert_idx = -1 - for key in state_dict.keys(): - if "base_model.experts." in key: - try: - parts = key.split("base_model.experts.")[1].split(".") - expert_idx = int(parts[0]) - max_expert_idx = max(max_expert_idx, expert_idx) - except (ValueError, IndexError): - pass - - if max_expert_idx >= 0: - # Determine the type of base_cnn each expert is using - base_class_for_experts = detect_base_class_type("base_model.experts.0.base_cnn") - - # Create AIIAmoe with the detected expert count and base class - base_model = AIIAmoe(config, num_experts=max_expert_idx+1, base_class=base_class_for_experts, **kwargs) - - # Check for AIIAchunked or AIIArecursive - elif any("base_model.chunked_cnn" in key for key in state_dict.keys()): - if any("recursion_depth" in key for key in state_dict.keys()): - # This is an AIIArecursive model - base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn") - base_model = AIIArecursive(config, base_class=base_class, **kwargs) - else: - # This is an AIIAchunked model - base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn") - base_model = AIIAchunked(config, base_class=base_class, **kwargs) - - # Check for AIIAExpert - elif any("base_model.base_cnn" in key for key in state_dict.keys()): - # Determine which base class the expert is using - base_class = detect_base_class_type("base_model.base_cnn") - base_model = AIIAExpert(config, base_class=base_class, **kwargs) - - # If none of the above, use AIIABase or AIIABaseShared directly - else: - base_class = detect_base_class_type("base_model") - base_model = base_class(config, **kwargs) - - # Create the aiuNN model with the detected base model - model = cls(base_model, config=base_model.config) - - # Handle precision conversion - dtype = None - if precision is not None: - if precision.lower() == 'fp16': - dtype = torch.float16 - elif precision.lower() == 'bf16': - if device == 'cuda' and not torch.cuda.is_bf16_supported(): - warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.") - dtype = torch.float16 - else: - dtype = torch.bfloat16 - else: - raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.") - - if dtype is not None: - for key, param in state_dict.items(): - if torch.is_tensor(param): - state_dict[key] = param.to(dtype) - - # Load the state dict - model.load_state_dict(state_dict) - return model - - if __name__ == "__main__": from aiia import AIIABase, AIIAConfig @@ -146,11 +46,11 @@ if __name__ == "__main__": ai_config = aiuNNConfig() base_model = AIIABase(config) # Instantiate Upsampler from the base model (works correctly). - upsampler = aiuNN(base_model, config=ai_config) - + upsampler = aiuNN(config=ai_config) + upsampler.load_base_model(base_model) # Save the model (both configuration and weights). - upsampler.save("aiunn") + upsampler.save_pretrained("aiunn") # Now load using the overridden load method; this will load the complete model. - upsampler_loaded = aiuNN.load("aiunn", precision="bf16") + upsampler_loaded = aiuNN.from_pretrained("aiunn") print("Updated configuration:", upsampler_loaded.config.__dict__) diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 39dea9a..5583d75 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -21,9 +21,8 @@ def real_model(tmp_path): base_model = AIIABase(config) # Make sure aiuNN is properly configured with all required attributes - upsampler = aiuNN(base_model, config=ai_config) - # Ensure the upsample attribute is properly set if needed - # upsampler.upsample = ... # Add any necessary initialization + upsampler = aiuNN(config=ai_config) + upsampler.load_base_model(base_model) # Save the model and config to temporary directory save_path = str(model_dir / "save") @@ -40,10 +39,10 @@ def real_model(tmp_path): json.dump(config_data, f) # Save model - upsampler.save(save_path) + upsampler.save_pretrained(save_path) # Load model in inference mode - inference_model = aiuNNInference(model_path=save_path, precision='fp16', device='cpu') + inference_model = aiuNNInference(model_path=save_path, device='cpu') return inference_model @@ -88,12 +87,3 @@ def test_convert_to_binary(inference): result = inference.convert_to_binary(test_image) assert isinstance(result, bytes) assert len(result) > 0 - -def test_process_batch(inference): - # Create test images - test_array = np.zeros((100, 100, 3), dtype=np.uint8) - test_images = [Image.fromarray(test_array) for _ in range(2)] - - results = inference.process_batch(test_images) - assert len(results) == 2 - assert all(isinstance(img, Image.Image) for img in results) \ No newline at end of file diff --git a/tests/upsampler/test_aiunn.py b/tests/upsampler/test_aiunn.py index aae0813..cdf11bc 100644 --- a/tests/upsampler/test_aiunn.py +++ b/tests/upsampler/test_aiunn.py @@ -10,39 +10,21 @@ def test_save_and_load_model(): config = AIIAConfig() ai_config = aiuNNConfig() base_model = AIIABase(config) - upsampler = aiuNN(base_model, config=ai_config) - + upsampler = aiuNN(config=ai_config) + upsampler.load_base_model(base_model) # Save the model save_path = os.path.join(tmpdirname, "model") - upsampler.save(save_path) + upsampler.save_pretrained(save_path) # Load the model - loaded_upsampler = aiuNN.load(save_path) + loaded_upsampler = aiuNN.from_pretrained(save_path) # Verify that the loaded model is the same as the original model assert isinstance(loaded_upsampler, aiuNN) - assert loaded_upsampler.config.__dict__ == upsampler.config.__dict__ + assert loaded_upsampler.config.hidden_size == upsampler.config.hidden_size + assert loaded_upsampler.config._activation_function == upsampler.config._activation_function + assert loaded_upsampler.config.architectures == upsampler.config.architectures -def test_save_and_load_model_with_precision(): - # Create a temporary directory to save the model - with tempfile.TemporaryDirectory() as tmpdirname: - # Create configurations and build a base model - config = AIIAConfig() - ai_config = aiuNNConfig() - base_model = AIIABase(config) - upsampler = aiuNN(base_model, config=ai_config) - - # Save the model - save_path = os.path.join(tmpdirname, "model") - upsampler.save(save_path) - - # Load the model with precision 'bf16' - loaded_upsampler = aiuNN.load(save_path, precision="bf16") - - # Verify that the loaded model is the same as the original model - assert isinstance(loaded_upsampler, aiuNN) - assert loaded_upsampler.config.__dict__ == upsampler.config.__dict__ if __name__ == "__main__": test_save_and_load_model() - test_save_and_load_model_with_precision() \ No newline at end of file