Compare commits

...

5 Commits

5 changed files with 16 additions and 13 deletions

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "aiunn"
version = "0.1.1"
version = "0.1.2"
description = "Finetuner for image upscaling using AIIA"
readme = "README.md"
requires-python = ">=3.10"

View File

@ -2,4 +2,5 @@ torch
aiia
pillow
torchvision
sklearn
sklearn
https://gitea.fabelous.app/Machine-Learning/AIIA.git

View File

@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup(
name="aiunn",
version="0.1.1",
version="0.1.2",
packages=find_packages(where="src"),
package_dir={"": "src"},
install_requires=[

View File

@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN
from .upsampler.config import aiuNNConfig
from .inference.inference import aiuNNInference
__version__ = "0.1.1"
__version__ = "0.1.2"

View File

@ -2,18 +2,18 @@ import os
import torch
import torch.nn as nn
import warnings
from aiia import AIIA, AIIAConfig, AIIABase
from aiia.model.Model import AIIA, AIIAConfig, AIIABase
from .config import aiuNNConfig
import warnings
class aiuNN(AIIA):
def __init__(self, base_model: AIIABase):
def __init__(self, base_model: AIIA, config:aiuNNConfig):
super().__init__(base_model.config)
self.base_model = base_model
# Pass the unified base configuration using the new parameter.
self.config = aiuNNConfig(base_config=base_model.config)
self.config = config
# Enhanced approach
scale_factor = self.config.upsample_scale
@ -28,11 +28,12 @@ class aiuNN(AIIA):
def forward(self, x):
x = self.base_model(x)
x = self.upsample(x)
x = self.to_rgb(x) # Ensures output has 3 channels.
x = self.base_model(x) # Get base features
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
return x
@classmethod
def load(cls, path, precision: str = None, **kwargs):
"""
@ -56,7 +57,7 @@ class aiuNN(AIIA):
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
# Import all model types
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
# Helper function to detect base class type from key patterns
def detect_base_class_type(keys_prefix):
@ -111,7 +112,7 @@ class aiuNN(AIIA):
base_model = base_class(config, **kwargs)
# Create the aiuNN model with the detected base model
model = cls(base_model)
model = cls(base_model, config=base_model.config)
# Handle precision conversion
dtype = None
@ -142,9 +143,10 @@ if __name__ == "__main__":
from aiia import AIIABase, AIIAConfig
# Create a configuration and build a base model.
config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config)
# Instantiate Upsampler from the base model (works correctly).
upsampler = aiuNN(base_model)
upsampler = aiuNN(base_model, config=ai_config)
# Save the model (both configuration and weights).
upsampler.save("hehe")