Compare commits

..

5 Commits

5 changed files with 16 additions and 13 deletions

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "aiunn" name = "aiunn"
version = "0.1.1" version = "0.1.2"
description = "Finetuner for image upscaling using AIIA" description = "Finetuner for image upscaling using AIIA"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -2,4 +2,5 @@ torch
aiia aiia
pillow pillow
torchvision torchvision
sklearn sklearn
https://gitea.fabelous.app/Machine-Learning/AIIA.git

View File

@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup( setup(
name="aiunn", name="aiunn",
version="0.1.1", version="0.1.2",
packages=find_packages(where="src"), packages=find_packages(where="src"),
package_dir={"": "src"}, package_dir={"": "src"},
install_requires=[ install_requires=[

View File

@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN
from .upsampler.config import aiuNNConfig from .upsampler.config import aiuNNConfig
from .inference.inference import aiuNNInference from .inference.inference import aiuNNInference
__version__ = "0.1.1" __version__ = "0.1.2"

View File

@ -2,18 +2,18 @@ import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import warnings import warnings
from aiia import AIIA, AIIAConfig, AIIABase from aiia.model.Model import AIIA, AIIAConfig, AIIABase
from .config import aiuNNConfig from .config import aiuNNConfig
import warnings import warnings
class aiuNN(AIIA): class aiuNN(AIIA):
def __init__(self, base_model: AIIABase): def __init__(self, base_model: AIIA, config:aiuNNConfig):
super().__init__(base_model.config) super().__init__(base_model.config)
self.base_model = base_model self.base_model = base_model
# Pass the unified base configuration using the new parameter. # Pass the unified base configuration using the new parameter.
self.config = aiuNNConfig(base_config=base_model.config) self.config = config
# Enhanced approach # Enhanced approach
scale_factor = self.config.upsample_scale scale_factor = self.config.upsample_scale
@ -28,11 +28,12 @@ class aiuNN(AIIA):
def forward(self, x): def forward(self, x):
x = self.base_model(x) x = self.base_model(x) # Get base features
x = self.upsample(x) x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
x = self.to_rgb(x) # Ensures output has 3 channels. x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
return x return x
@classmethod @classmethod
def load(cls, path, precision: str = None, **kwargs): def load(cls, path, precision: str = None, **kwargs):
""" """
@ -56,7 +57,7 @@ class aiuNN(AIIA):
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device) state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
# Import all model types # Import all model types
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
# Helper function to detect base class type from key patterns # Helper function to detect base class type from key patterns
def detect_base_class_type(keys_prefix): def detect_base_class_type(keys_prefix):
@ -111,7 +112,7 @@ class aiuNN(AIIA):
base_model = base_class(config, **kwargs) base_model = base_class(config, **kwargs)
# Create the aiuNN model with the detected base model # Create the aiuNN model with the detected base model
model = cls(base_model) model = cls(base_model, config=base_model.config)
# Handle precision conversion # Handle precision conversion
dtype = None dtype = None
@ -142,9 +143,10 @@ if __name__ == "__main__":
from aiia import AIIABase, AIIAConfig from aiia import AIIABase, AIIAConfig
# Create a configuration and build a base model. # Create a configuration and build a base model.
config = AIIAConfig() config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config) base_model = AIIABase(config)
# Instantiate Upsampler from the base model (works correctly). # Instantiate Upsampler from the base model (works correctly).
upsampler = aiuNN(base_model) upsampler = aiuNN(base_model, config=ai_config)
# Save the model (both configuration and weights). # Save the model (both configuration and weights).
upsampler.save("hehe") upsampler.save("hehe")