Compare commits
5 Commits
25f4225baf
...
4501b6e34a
Author | SHA1 | Date |
---|---|---|
|
4501b6e34a | |
|
de0da5de82 | |
|
68a27f00c1 | |
|
399a7c0f69 | |
|
5c668e3c7b |
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "aiunn"
|
name = "aiunn"
|
||||||
version = "0.1.1"
|
version = "0.1.2"
|
||||||
description = "Finetuner for image upscaling using AIIA"
|
description = "Finetuner for image upscaling using AIIA"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|
|
@ -2,4 +2,5 @@ torch
|
||||||
aiia
|
aiia
|
||||||
pillow
|
pillow
|
||||||
torchvision
|
torchvision
|
||||||
sklearn
|
sklearn
|
||||||
|
https://gitea.fabelous.app/Machine-Learning/AIIA.git
|
2
setup.py
2
setup.py
|
@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="aiunn",
|
name="aiunn",
|
||||||
version="0.1.1",
|
version="0.1.2",
|
||||||
packages=find_packages(where="src"),
|
packages=find_packages(where="src"),
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
|
|
|
@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN
|
||||||
from .upsampler.config import aiuNNConfig
|
from .upsampler.config import aiuNNConfig
|
||||||
from .inference.inference import aiuNNInference
|
from .inference.inference import aiuNNInference
|
||||||
|
|
||||||
__version__ = "0.1.1"
|
__version__ = "0.1.2"
|
|
@ -2,18 +2,18 @@ import os
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import warnings
|
import warnings
|
||||||
from aiia import AIIA, AIIAConfig, AIIABase
|
from aiia.model.Model import AIIA, AIIAConfig, AIIABase
|
||||||
from .config import aiuNNConfig
|
from .config import aiuNNConfig
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
class aiuNN(AIIA):
|
class aiuNN(AIIA):
|
||||||
def __init__(self, base_model: AIIABase):
|
def __init__(self, base_model: AIIA, config:aiuNNConfig):
|
||||||
super().__init__(base_model.config)
|
super().__init__(base_model.config)
|
||||||
self.base_model = base_model
|
self.base_model = base_model
|
||||||
|
|
||||||
# Pass the unified base configuration using the new parameter.
|
# Pass the unified base configuration using the new parameter.
|
||||||
self.config = aiuNNConfig(base_config=base_model.config)
|
self.config = config
|
||||||
|
|
||||||
# Enhanced approach
|
# Enhanced approach
|
||||||
scale_factor = self.config.upsample_scale
|
scale_factor = self.config.upsample_scale
|
||||||
|
@ -28,11 +28,12 @@ class aiuNN(AIIA):
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.base_model(x)
|
x = self.base_model(x) # Get base features
|
||||||
x = self.upsample(x)
|
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
|
||||||
x = self.to_rgb(x) # Ensures output has 3 channels.
|
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path, precision: str = None, **kwargs):
|
def load(cls, path, precision: str = None, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -56,7 +57,7 @@ class aiuNN(AIIA):
|
||||||
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
|
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
|
||||||
|
|
||||||
# Import all model types
|
# Import all model types
|
||||||
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
|
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
|
||||||
|
|
||||||
# Helper function to detect base class type from key patterns
|
# Helper function to detect base class type from key patterns
|
||||||
def detect_base_class_type(keys_prefix):
|
def detect_base_class_type(keys_prefix):
|
||||||
|
@ -111,7 +112,7 @@ class aiuNN(AIIA):
|
||||||
base_model = base_class(config, **kwargs)
|
base_model = base_class(config, **kwargs)
|
||||||
|
|
||||||
# Create the aiuNN model with the detected base model
|
# Create the aiuNN model with the detected base model
|
||||||
model = cls(base_model)
|
model = cls(base_model, config=base_model.config)
|
||||||
|
|
||||||
# Handle precision conversion
|
# Handle precision conversion
|
||||||
dtype = None
|
dtype = None
|
||||||
|
@ -142,9 +143,10 @@ if __name__ == "__main__":
|
||||||
from aiia import AIIABase, AIIAConfig
|
from aiia import AIIABase, AIIAConfig
|
||||||
# Create a configuration and build a base model.
|
# Create a configuration and build a base model.
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
|
ai_config = aiuNNConfig()
|
||||||
base_model = AIIABase(config)
|
base_model = AIIABase(config)
|
||||||
# Instantiate Upsampler from the base model (works correctly).
|
# Instantiate Upsampler from the base model (works correctly).
|
||||||
upsampler = aiuNN(base_model)
|
upsampler = aiuNN(base_model, config=ai_config)
|
||||||
|
|
||||||
# Save the model (both configuration and weights).
|
# Save the model (both configuration and weights).
|
||||||
upsampler.save("hehe")
|
upsampler.save("hehe")
|
||||||
|
|
Loading…
Reference in New Issue