class UpscaleDataset(Dataset):
    """Dataset of paired low/high resolution images stored in parquet files.

    Every parquet file must expose an ``image_512`` and an ``image_1024``
    column whose cells hold an encoded image, either as raw ``bytes`` or as a
    base64 ``str``. Items are ``(low_res, high_res)`` pairs; transparent areas
    are flattened onto a black background and both images are resized.
    """

    IMAGE_COLUMNS = ('image_512', 'image_1024')

    def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000,
                 low_res_size=(410, 410), high_res_size=(820, 820)):
        """Load and validate the image columns of every parquet file.

        Args:
            parquet_files: Paths of the parquet files to read.
            transform: Optional callable applied to both PIL images of a pair.
            samples_per_file: Maximum number of leading rows taken per file.
            low_res_size: ``(width, height)`` the low resolution image is resized to.
            high_res_size: ``(width, height)`` the high resolution image is resized to.

        Raises:
            ValueError: If a cell is neither ``bytes`` nor ``str``.
        """
        columns = list(self.IMAGE_COLUMNS)
        # Collect the per-file frames first and concatenate once; concatenating
        # inside the loop re-copies all previously loaded rows on every file.
        frames = [
            pd.read_parquet(parquet_file, columns=columns).head(samples_per_file)
            for parquet_file in parquet_files
        ]
        combined_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(columns=columns)

        # Validate rows (ensuring each value is bytes or str)
        self.df = combined_df if combined_df.empty else combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.low_res_size = tuple(low_res_size)
        self.high_res_size = tuple(high_res_size)
        self.failed_indices = set()

    def _validate_row(self, row):
        """Return *row* unchanged, raising ``ValueError`` on a non bytes/str image cell."""
        for col in self.IMAGE_COLUMNS:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        """Return the raw image bytes for *data* (base64 ``str`` or ``bytes``).

        Raises:
            RuntimeError: If *data* has an unsupported type or cannot be decoded.
        """
        try:
            if isinstance(data, str):
                return base64.b64decode(data)
            if isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Decoding failed: {str(e)}") from e

    @staticmethod
    def _open_on_black(image_bytes):
        """Decode *image_bytes* and composite it over a black RGB background."""
        rgba = Image.open(io.BytesIO(image_bytes)).convert('RGBA')
        rgb = Image.new("RGB", rgba.size, (0, 0, 0))
        # The alpha channel is used as paste mask, so transparent pixels stay black.
        rgb.paste(rgba, mask=rgba.split()[3])
        return rgb

    def _load_pair(self, idx):
        """Build the ``(low_res, high_res)`` pair stored at row *idx*."""
        row = self.df.iloc[idx]
        low_res_bytes = self._decode_image(row['image_512'])
        high_res_bytes = self._decode_image(row['image_1024'])
        # Let Pillow load partially written images instead of raising.
        ImageFile.LOAD_TRUNCATED_IMAGES = True

        # Resize the images to reduce VRAM usage
        low_res = self._open_on_black(low_res_bytes).resize(self.low_res_size, Image.LANCZOS)
        high_res = self._open_on_black(high_res_bytes).resize(self.high_res_size, Image.LANCZOS)

        # If a transform is provided (e.g. conversion to Tensor), apply it
        if self.transform:
            low_res = self.transform(low_res)
            high_res = self.transform(high_res)
        return low_res, high_res

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the pair at *idx*, falling back to the next loadable row.

        Rows that fail to load are remembered in ``failed_indices`` and skipped
        on later calls.

        Raises:
            IndexError: If the dataset is empty.
            RuntimeError: If no row of the dataset can be loaded.
        """
        total = len(self)
        if total == 0:
            raise IndexError("Cannot index an empty UpscaleDataset")

        candidate = idx
        # Bounded iteration instead of unbounded recursion: one extra attempt
        # covers an initial out-of-range index before wrapping into range.
        for _ in range(total + 1):
            if candidate not in self.failed_indices:
                try:
                    return self._load_pair(candidate)
                except Exception as e:
                    # Best effort by design: a broken sample must not stop training.
                    print(f"\nError at index {candidate}: {str(e)}")
                    self.failed_indices.add(candidate)
            candidate = (candidate + 1) % total
        raise RuntimeError("No loadable sample left in UpscaleDataset")


if __name__ == "__main__":
    # Load your base model and upscaler
    pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
    base_model = AIIABase.load(pretrained_model_path, precision="bf16")
    upscaler = aiuNN(base_model)

    # Create trainer with your dataset class
    trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)

    # Load data using parameters for your dataset
    dataset_params = {
        'parquet_files': [
            "/root/training_data/vision-dataset/image_upscaler.parquet",
            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
        ],
        'transform': transforms.Compose([transforms.ToTensor()]),
        'samples_per_file': 5000
    }
    trainer.load_data(dataset_params=dataset_params, batch_size=1)

    # Fine-tune the model
    trainer.finetune(output_path="trained_models")
from pathlib import Path

from setuptools import setup, find_packages


def _read_requirements(path):
    """Return the requirement specifiers in *path*, skipping blank and comment lines."""
    # The context manager closes the file; the original left the handle open.
    with open(path, encoding="utf-8") as handle:
        stripped = (line.strip() for line in handle)
        # Test the stripped line so indented comments are skipped as well.
        return [line for line in stripped if line and not line.startswith("#")]


# NOTE(review): requirements.txt lists `sklearn`; on PyPI that name is presumably
# the deprecated alias of scikit-learn - confirm it still installs.
setup(
    name="aiunn",
    version="0.1.1",
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    # Resolve next to this file so the build does not depend on the working directory.
    install_requires=_read_requirements(Path(__file__).resolve().parent / "requirements.txt"),
    python_requires=">=3.10",
)
a/src/aiunn/finetune/trainer.py b/src/aiunn/finetune/trainer.py new file mode 100644 index 0000000..b94d57d --- /dev/null +++ b/src/aiunn/finetune/trainer.py @@ -0,0 +1,290 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import os +import csv +from torch.amp import autocast, GradScaler +from torch.utils.data import DataLoader +from tqdm import tqdm +from torch.utils.checkpoint import checkpoint +import gc +import time +import shutil + + +class EarlyStopping: + def __init__(self, patience=3, min_delta=0.001): + # Number of epochs with no significant improvement before stopping + # Minimum change in loss required to count as an improvement + self.patience = patience + self.min_delta = min_delta + self.best_loss = float('inf') + self.counter = 0 + self.early_stop = False + + def __call__(self, epoch_loss): + if epoch_loss < self.best_loss - self.min_delta: + self.best_loss = epoch_loss + self.counter = 0 + return True # Improved + else: + self.counter += 1 + if self.counter >= self.patience: + self.early_stop = True + return False # Not improved + +class aiuNNTrainer: + def __init__(self, upscaler_model, dataset_class=None): + """ + Initialize the upscaler trainer + + Args: + upscaler_model: The model to fine-tune + dataset_class: The dataset class to use for loading data (optional) + """ + self.model = upscaler_model + self.dataset_class = dataset_class + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = self.model.to(self.device, memory_format=torch.channels_last) + self.criterion = nn.MSELoss() + self.optimizer = None + self.scaler = GradScaler() + self.best_loss = float('inf') + self.use_checkpointing = True + self.data_loader = None + self.validation_loader = None + self.log_dir = None + + def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None): + """ + Load data using either a custom dataset instance or the dataset class provided 
at initialization + + Args: + dataset_params (dict/list): Parameters to pass to the dataset class constructor + batch_size (int): Batch size for training + validation_split (float): Proportion of data to use for validation + custom_train_dataset: A pre-instantiated dataset to use for training (optional) + custom_val_dataset: A pre-instantiated dataset to use for validation (optional) + """ + # If custom datasets are provided directly, use them + if custom_train_dataset is not None: + train_dataset = custom_train_dataset + val_dataset = custom_val_dataset if custom_val_dataset is not None else None + else: + # Otherwise instantiate dataset using the class and parameters + if self.dataset_class is None: + raise ValueError("No dataset class provided. Either provide a dataset class at initialization or custom datasets.") + + # Create dataset instance + dataset = self.dataset_class(**dataset_params if isinstance(dataset_params, dict) else {'parquet_files': dataset_params}) + + # Split into train and validation sets + dataset_size = len(dataset) + val_size = int(validation_split * dataset_size) + train_size = dataset_size - val_size + + train_dataset, val_dataset = torch.utils.data.random_split( + dataset, [train_size, val_size] + ) + + # Create data loaders + self.data_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + pin_memory=True + ) + + if val_dataset is not None: + self.validation_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + pin_memory=True + ) + print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples") + else: + self.validation_loader = None + print(f"Loaded {len(train_dataset)} training samples (no validation set)") + + return self.data_loader, self.validation_loader + + def _setup_logging(self, output_path): + """Set up directory structure for logging and model checkpoints""" + timestamp = time.strftime("%Y%m%d-%H%M%S") + self.log_dir = 
os.path.join(output_path, f"training_run_{timestamp}") + os.makedirs(self.log_dir, exist_ok=True) + + # Create checkpoint directory + self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints") + os.makedirs(self.checkpoint_dir, exist_ok=True) + + # Set up CSV logging + self.csv_path = os.path.join(self.log_dir, 'training_log.csv') + with open(self.csv_path, mode='w', newline='') as file: + writer = csv.writer(file) + if self.validation_loader: + writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved']) + else: + writer.writerow(['Epoch', 'Train Loss', 'Improved']) + + def _evaluate(self): + """Evaluate the model on validation data""" + if self.validation_loader is None: + return 0.0 + + self.model.eval() + val_loss = 0.0 + + with torch.no_grad(): + for low_res, high_res in tqdm(self.validation_loader, desc="Validating"): + low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last) + high_res = high_res.to(self.device, non_blocking=True) + + with autocast(device_type=self.device.type): + outputs = self.model(low_res) + loss = self.criterion(outputs, high_res) + + val_loss += loss.item() + + del low_res, high_res, outputs, loss + + self.model.train() + return val_loss + + def _save_checkpoint(self, epoch, is_best=False): + """Save model checkpoint""" + checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt") + best_model_path = os.path.join(self.log_dir, "best_model") + + # Save the model checkpoint + self.model.save(checkpoint_path) + + # If this is the best model so far, copy it to best_model + if is_best: + if os.path.exists(best_model_path): + shutil.rmtree(best_model_path) + self.model.save(best_model_path) + print(f"Saved new best model with loss: {self.best_loss:.6f}") + + def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001): + """ + Finetune the upscaler model + + Args: + output_path (str): Directory to save models and logs + epochs (int): Maximum number of training 
epochs + lr (float): Learning rate + patience (int): Early stopping patience + min_delta (float): Minimum improvement for early stopping + """ + # Check if data is loaded + if self.data_loader is None: + raise ValueError("Data not loaded. Call load_data first.") + + # Setup optimizer + self.optimizer = optim.Adam(self.model.parameters(), lr=lr) + + # Set up logging + self._setup_logging(output_path) + + # Setup early stopping + early_stopping = EarlyStopping(patience=patience, min_delta=min_delta) + + # Training loop + self.model.train() + + for epoch in range(epochs): + # Training phase + epoch_loss = 0.0 + progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}") + + for low_res, high_res in progress_bar: + # Move data to GPU with channels_last format where possible + low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last) + high_res = high_res.to(self.device, non_blocking=True) + + self.optimizer.zero_grad() + + with autocast(device_type=self.device.type): + if self.use_checkpointing: + # Ensure the input tensor requires gradient so that checkpointing records the computation graph + low_res.requires_grad_() + outputs = checkpoint(self.model, low_res) + else: + outputs = self.model(low_res) + loss = self.criterion(outputs, high_res) + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + + epoch_loss += loss.item() + progress_bar.set_postfix({'loss': loss.item()}) + + # Optionally delete variables to free memory + del low_res, high_res, outputs, loss + + # Calculate average epoch loss + avg_train_loss = epoch_loss / len(self.data_loader) + + # Validation phase (if validation loader exists) + if self.validation_loader: + val_loss = self._evaluate() / len(self.validation_loader) + is_improved = val_loss < self.best_loss + if is_improved: + self.best_loss = val_loss + + # Log results + print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: 
{val_loss:.6f}") + with open(self.csv_path, mode='a', newline='') as file: + writer = csv.writer(file) + writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"]) + else: + # If no validation, use training loss for improvement tracking + is_improved = avg_train_loss < self.best_loss + if is_improved: + self.best_loss = avg_train_loss + + # Log results + print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}") + with open(self.csv_path, mode='a', newline='') as file: + writer = csv.writer(file) + writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"]) + + # Save checkpoint + self._save_checkpoint(epoch + 1, is_best=is_improved) + + # Perform garbage collection and clear GPU cache after each epoch + gc.collect() + torch.cuda.empty_cache() + + # Check early stopping + early_stopping(val_loss if self.validation_loader else avg_train_loss) + if early_stopping.early_stop: + print(f"Early stopping triggered at epoch {epoch + 1}") + break + + return self.best_loss + + def save(self, output_path=None): + """ + Save the best model to the specified path + + Args: + output_path (str, optional): Path to save the model. If None, uses the best model from training. 
+ """ + if output_path is None and self.log_dir is not None: + best_model_path = os.path.join(self.log_dir, "best_model") + if os.path.exists(best_model_path): + print(f"Best model already saved at {best_model_path}") + return best_model_path + else: + output_path = os.path.join(self.log_dir, "final_model") + + if output_path is None: + raise ValueError("No output path specified and no training has been done yet.") + + self.model.save(output_path) + print(f"Model saved to {output_path}") + return output_path \ No newline at end of file diff --git a/src/aiunn/inference/__init__.py b/src/aiunn/inference/__init__.py new file mode 100644 index 0000000..798de24 --- /dev/null +++ b/src/aiunn/inference/__init__.py @@ -0,0 +1,3 @@ +from .inference import aiuNNInference + +__all__ = ["aiuNNInference"] \ No newline at end of file diff --git a/src/aiunn/inference/inference.py b/src/aiunn/inference/inference.py new file mode 100644 index 0000000..d288931 --- /dev/null +++ b/src/aiunn/inference/inference.py @@ -0,0 +1,226 @@ +import os +import torch +import numpy as np +from PIL import Image +import io +from typing import Union, Optional, Tuple, List +from ..upsampler.aiunn import aiuNN + + +class aiuNNInference: + """ + Inference class for aiuNN upsampling model. + Handles model loading, image upscaling, and output processing. + """ + def __init__(self, model_path: str, precision: Optional[str] = None, device: Optional[str] = None): + """ + Initialize the inference class by loading the aiuNN model. 
class aiuNNInference:
    """
    Inference class for aiuNN upsampling model.
    Handles model loading, image upscaling, and output processing.
    """

    def __init__(self, model_path: str, precision: Optional[str] = None, device: Optional[str] = None):
        """
        Initialize the inference class by loading the aiuNN model.

        Args:
            model_path: Path to the saved model directory
            precision: Optional precision setting ('fp16', 'bf16', or None for default)
            device: Optional device specification ('cuda', 'cpu', or None for auto-detection)
        """
        # Set device
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        # Load the model with specified precision
        self.model = aiuNN.load(model_path, precision=precision)
        self.model.to(self.device)
        self.model.eval()

        # Store configuration for reference
        self.config = self.model.config

    def _model_dtype(self) -> torch.dtype:
        """Return the floating point dtype of the model parameters (float32 if undetermined)."""
        first_param = next(self.model.parameters(), None)
        if first_param is not None and first_param.is_floating_point():
            return first_param.dtype
        return torch.float32

    def preprocess_image(self, image: Union[str, "Image.Image", np.ndarray, torch.Tensor]) -> torch.Tensor:
        """
        Preprocess the input image to match model requirements.

        Args:
            image: Input image as file path, PIL Image, numpy array, or torch tensor.
                Arrays may be CHW or HWC with three channels.

        Returns:
            Tensor of shape (N, C, H, W) on ``self.device`` in the dtype of the model
            parameters. Inputs whose maximum exceeds 1.0 are divided by 255.

        Raises:
            TypeError: If *image* is of an unsupported type.
        """
        if isinstance(image, torch.Tensor):
            tensor = image
        elif isinstance(image, np.ndarray):
            if image.shape[0] == 3:
                # Already in CHW format
                array = image
            elif image.shape[-1] == 3:
                # HWC to CHW format
                array = image.transpose(2, 0, 1)
            else:
                array = image
            tensor = torch.from_numpy(array)
        else:
            if isinstance(image, (str, os.PathLike)):
                # Load from file path; the context manager releases the file handle
                with Image.open(image) as opened:
                    image = opened.convert('RGB')
            if not isinstance(image, Image.Image):
                raise TypeError(f"Unsupported image type: {type(image)}")
            # Grayscale, palette or RGBA images do not yield an (H, W, 3) array,
            # so normalise every PIL input to RGB before the HWC -> CHW transpose.
            array = np.array(image.convert('RGB')).transpose(2, 0, 1)
            tensor = torch.from_numpy(array)

        # Integer inputs (uint8 images) must become floating point before scaling
        tensor = tensor.float()

        # Normalize to [0, 1] range if needed
        if tensor.max() > 1.0:
            tensor = tensor / 255.0

        # Add batch dimension if not present
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)

        # Match the dtype of the weights; a float32 input would fail against a
        # half precision model.
        return tensor.to(device=self.device, dtype=self._model_dtype())

    def postprocess_tensor(self, tensor: torch.Tensor) -> "Image.Image":
        """
        Convert output tensor to PIL Image.

        Args:
            tensor: Output tensor from model, CHW or with a leading batch dimension of 1

        Returns:
            Processed PIL Image
        """
        # numpy has no bfloat16, so convert to float32 before leaving torch
        output = tensor.detach().float().cpu().squeeze(0).numpy()

        # Scale to [0, 255]; round instead of truncating so values such as
        # 127.9999 map back to 128
        output = np.rint(np.clip(output * 255, 0, 255)).astype(np.uint8)

        # Convert from CHW to HWC for PIL
        output = output.transpose(1, 2, 0)

        # Create PIL Image
        return Image.fromarray(output)

    @torch.no_grad()
    def upscale(self, image: Union[str, "Image.Image", np.ndarray, torch.Tensor]) -> "Image.Image":
        """
        Upscale an image using the aiuNN model.

        Args:
            image: Input image to upscale

        Returns:
            Upscaled image as PIL Image
        """
        input_tensor = self.preprocess_image(image)
        output_tensor = self.model(input_tensor)
        return self.postprocess_tensor(output_tensor)

    def save(self, image: "Image.Image", output_path: str, format: Optional[str] = None) -> None:
        """
        Save the upscaled image to a file.

        Args:
            image: PIL Image to save
            output_path: Path where the image should be saved
            format: Optional format override (e.g., 'PNG', 'JPEG')
        """
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

        # Save the image
        image.save(output_path, format=format)

    def convert_to_binary(self, image: "Image.Image", format: str = 'PNG') -> bytes:
        """
        Convert the image to binary data.

        Args:
            image: PIL Image to convert
            format: Image format to use for binary conversion

        Returns:
            Binary representation of the image
        """
        # Use BytesIO to convert to binary
        binary_output = io.BytesIO()
        image.save(binary_output, format=format)
        return binary_output.getvalue()

    def process_batch(self,
                      images: List[Union[str, "Image.Image"]],
                      output_dir: Optional[str] = None,
                      save_format: str = 'PNG',
                      return_binary: bool = False) -> Union[List["Image.Image"], List[bytes], None]:
        """
        Process multiple images in batch.

        Args:
            images: List of input images (paths or PIL Images)
            output_dir: Optional directory to save results
            save_format: Format to use when saving images
            return_binary: Whether to return binary data instead of PIL Images

        Returns:
            List of processed images or binary data, or None if only saving
        """
        results = []

        for i, img in enumerate(images):
            # Upscale the image
            upscaled = self.upscale(img)

            # Save if output directory is provided
            if output_dir:
                # Reuse the input file name for paths, a running index otherwise
                if isinstance(img, str):
                    base, _ = os.path.splitext(os.path.basename(img))
                else:
                    base = f"upscaled_{i}"

                output_path = os.path.join(output_dir, f"{base}.{save_format.lower()}")
                self.save(upscaled, output_path, format=save_format)

            # Add to results based on return type
            if return_binary:
                results.append(self.convert_to_binary(upscaled, format=save_format))
            else:
                results.append(upscaled)

        return results if (not output_dir or return_binary or not save_format) else None


# Example usage (can be removed)
if __name__ == "__main__":
    # Initialize inference with a model path
    inferencer = aiuNNInference("path/to/model", precision="bf16")

    # Upscale a single image
    upscaled_image = inferencer.upscale("input_image.jpg")

    # Save the result
    inferencer.save(upscaled_image, "output_image.png")

    # Convert to binary
    binary_data = inferencer.convert_to_binary(upscaled_image)

    # Process a batch of images
    inferencer.process_batch(
        ["image1.jpg", "image2.jpg"],
        output_dir="output_folder",
        save_format="PNG"
    )
class aiuNN(AIIA):
    """Upscaler that runs a base AIIA model, upsamples its output and projects it to RGB."""

    def __init__(self, base_model: AIIABase):
        """Wrap *base_model* with an upsample layer and a 1x1 convolution to 3 channels."""
        super().__init__(base_model.config)
        self.base_model = base_model

        # Pass the unified base configuration using the new parameter.
        self.config = aiuNNConfig(base_config=base_model.config)

        self.upsample = nn.Upsample(
            scale_factor=self.config.upsample_scale,
            mode=self.config.upsample_mode,
            align_corners=self.config.upsample_align_corners
        )
        # Conversion layer: change from hidden size channels to 3 channels.
        self.to_rgb = nn.Conv2d(
            in_channels=self.base_model.config.hidden_size,
            out_channels=3,
            kernel_size=1
        )

    def forward(self, x):
        """Apply the base model, then upsample, then convert to 3 output channels."""
        x = self.base_model(x)
        x = self.upsample(x)
        x = self.to_rgb(x)  # Ensures output has 3 channels.
        return x

    @staticmethod
    def _resolve_dtype(precision, device):
        """Map a precision name to a torch dtype (``None`` when no precision is requested).

        Raises:
            ValueError: If *precision* is neither 'fp16' nor 'bf16'.
        """
        if precision is None:
            return None
        name = precision.lower()
        if name == 'fp16':
            return torch.float16
        if name == 'bf16':
            if device == 'cuda' and not torch.cuda.is_bf16_supported():
                warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
                return torch.float16
            return torch.bfloat16
        raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")

    @classmethod
    def load(cls, path, precision: str = None):
        """Rebuild an upsampler from the configuration and ``model.pth`` stored in *path*.

        Args:
            path: Directory containing the saved configuration and ``model.pth``.
            precision: Optional 'fp16' or 'bf16' (case-insensitive); floating point
                tensors of the state dict are converted to it before loading.

        Returns:
            The upsampler with the stored weights loaded.

        Raises:
            ValueError: If *precision* is not supported.
        """
        # Load the configuration from disk.
        config = AIIAConfig.load(path)
        # Reconstruct the base model from the loaded configuration.
        base_model = AIIABase(config)
        # Instantiate the Upsampler using the proper base model.
        upsampler = cls(base_model)

        # Load state dict and handle precision conversion if needed.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # NOTE(review): torch.load unpickles the file; only load checkpoints from
        # trusted sources (or pass weights_only=True if the file is a plain state dict).
        state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)

        dtype = cls._resolve_dtype(precision, device)
        if dtype is not None:
            # Convert floating point tensors only: integer tensors would lose
            # exactness in half precision.
            state_dict = {
                key: value.to(dtype) if torch.is_tensor(value) and value.is_floating_point() else value
                for key, value in state_dict.items()
            }
        # NOTE(review): load_state_dict copies values into the existing parameters,
        # so the module presumably keeps its construction dtype and `precision` only
        # rounds the weights - confirm whether the module itself should be converted.
        upsampler.load_state_dict(state_dict)
        return upsampler


if __name__ == "__main__":
    from aiia import AIIABase, AIIAConfig
    # Create a configuration and build a base model.
    config = AIIAConfig()
    base_model = AIIABase(config)
    # Instantiate Upsampler from the base model (works correctly).
    upsampler = aiuNN(base_model)

    # Save the model (both configuration and weights).
    upsampler.save("hehe")

    # Now load using the overridden load method; this will load the complete model.
    upsampler_loaded = aiuNN.load("hehe", precision="bf16")
    print("Updated configuration:", upsampler_loaded.config.__dict__)
class aiuNNConfig(AIIAConfig):
    """AIIA configuration extended with the parameters of the upsample layer."""

    def __init__(
        self,
        base_config=None,
        upsample_scale: int = 2,
        upsample_mode: str = 'bilinear',
        upsample_align_corners: bool = False,
        layers=None,
        **kwargs
    ):
        """Merge *base_config* and *kwargs* into the base configuration and add upsample settings.

        Args:
            base_config: Object exposing ``to_dict()`` or a plain dict with base settings (optional).
            upsample_scale: Scale factor of the upsample layer.
            upsample_mode: Interpolation mode of the upsample layer.
            upsample_align_corners: ``align_corners`` flag of the upsample layer.
            layers: Optional list of layer description dicts; it is copied, not modified.
            **kwargs: Additional base settings; they override values from *base_config*.
        """
        # Start with a single configuration dictionary.
        config_data = {}
        if base_config is not None:
            # If base_config is an object with a to_dict method, use it.
            if hasattr(base_config, "to_dict"):
                config_data.update(base_config.to_dict())
            elif isinstance(base_config, dict):
                config_data.update(base_config)

        # Update with any additional keyword arguments (if needed).
        config_data.update(kwargs)

        # Initialize base AIIAConfig with a single merged configuration.
        super().__init__(**config_data)

        # Upsampler-specific parameters.
        self.upsample_scale = upsample_scale
        self.upsample_mode = upsample_mode
        self.upsample_align_corners = upsample_align_corners

        # Copy the list so appending the upsample entry below never mutates the
        # caller's list.
        self.layers = list(layers) if layers is not None else []

        # Add the upsample layer details only once.
        self.add_upsample_layer()

    def add_upsample_layer(self):
        """Append the upsample layer description to ``self.layers`` unless one named 'Upsample' exists."""
        upsample_layer = {
            'name': 'Upsample',
            'type': 'nn.Upsample',
            'scale_factor': self.upsample_scale,
            'mode': self.upsample_mode,
            'align_corners': self.upsample_align_corners
        }
        # Append the layer only if it isn't already present.
        if not any(layer.get('name') == 'Upsample' for layer in self.layers):
            self.layers.append(upsample_layer)