Compare commits
3 Commits
68eb26fa6c
...
c92fa68e92
Author | SHA1 | Date |
---|---|---|
|
c92fa68e92 | |
|
8010605f44 | |
|
6d1fc4c88d |
|
@@ -3,7 +3,7 @@ requires = ["setuptools>=45", "wheel"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
[project]
|
[project]
|
||||||
name = "aiunn"
|
name = "aiunn"
|
||||||
version = "0.3.0"
|
version = "0.4.0"
|
||||||
description = "Finetuner for image upscaling using AIIA"
|
description = "Finetuner for image upscaling using AIIA"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|
2
setup.py
2
setup.py
|
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="aiunn",
|
name="aiunn",
|
||||||
version="0.3.0",
|
version="0.4.0",
|
||||||
packages=find_packages(where="src"),
|
packages=find_packages(where="src"),
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
|
|
|
@@ -4,4 +4,4 @@ from .upsampler.aiunn import aiuNN
|
||||||
from .upsampler.config import aiuNNConfig
|
from .upsampler.config import aiuNNConfig
|
||||||
from .inference.inference import aiuNNInference
|
from .inference.inference import aiuNNInference
|
||||||
|
|
||||||
__version__ = "0.3.0"
|
__version__ = "0.4.0"
|
|
@@ -1,39 +1,44 @@
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import warnings
|
|
||||||
from aiia.model.Model import AIIAConfig, AIIABase
|
from aiia.model.Model import AIIAConfig, AIIABase
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
from .config import aiuNNConfig
|
from .config import aiuNNConfig
|
||||||
import warnings
|
|
||||||
|
|
||||||
class aiuNN(PreTrainedModel):
|
class aiuNN(PreTrainedModel):
|
||||||
config_class = aiuNNConfig
|
config_class = aiuNNConfig
|
||||||
|
|
||||||
def __init__(self, config: aiuNNConfig, base_model: PreTrainedModel = None):
|
def __init__(self, config: aiuNNConfig, base_model: PreTrainedModel = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
# Copy base layers into aiuNN for self-containment and portability
|
||||||
if base_model is not None:
|
if base_model is not None:
|
||||||
# Only copy submodules if base_model is provided
|
if hasattr(base_model, 'cnn'):
|
||||||
self.base_layers = nn.Sequential(*[layer for layer in base_model.cnn])
|
self.base_layers = nn.Sequential(*[layer for layer in base_model.cnn])
|
||||||
|
elif hasattr(base_model, 'shared_layer') and hasattr(base_model, 'unique_layers'):
|
||||||
|
layers = [base_model.shared_layer, base_model.activation, base_model.max_pool]
|
||||||
|
for ul in base_model.unique_layers:
|
||||||
|
layers.extend([ul, base_model.activation, base_model.max_pool])
|
||||||
|
self.base_layers = nn.Sequential(*layers)
|
||||||
|
else:
|
||||||
|
self.base_layers = self._build_base_layers_from_config(config)
|
||||||
else:
|
else:
|
||||||
# At inference, modules will be loaded from state_dict
|
|
||||||
self.base_layers = self._build_base_layers_from_config(config)
|
self.base_layers = self._build_base_layers_from_config(config)
|
||||||
|
|
||||||
scale_factor = self.config.upsample_scale
|
# Bilinear upsampling head
|
||||||
out_channels = self.config.num_channels * (scale_factor ** 2)
|
self.upsample = nn.Upsample(
|
||||||
self.pixel_shuffle_conv = nn.Conv2d(
|
scale_factor=self.config.upsample_scale,
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False
|
||||||
|
)
|
||||||
|
self.final_conv = nn.Conv2d(
|
||||||
in_channels=self.config.hidden_size,
|
in_channels=self.config.hidden_size,
|
||||||
out_channels=out_channels,
|
out_channels=self.config.num_channels,
|
||||||
kernel_size=self.config.kernel_size,
|
kernel_size=3,
|
||||||
padding=1
|
padding=1
|
||||||
)
|
)
|
||||||
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
|
||||||
|
|
||||||
def _build_base_layers_from_config(self, config):
|
def _build_base_layers_from_config(self, config):
|
||||||
"""
|
|
||||||
Reconstruct the base layers (e.g., CNN) using only the config.
|
|
||||||
This must match exactly how your base model builds its layers!
|
|
||||||
"""
|
|
||||||
layers = []
|
layers = []
|
||||||
in_channels = config.num_channels
|
in_channels = config.num_channels
|
||||||
for _ in range(config.num_hidden_layers):
|
for _ in range(config.num_hidden_layers):
|
||||||
|
@@ -47,14 +52,12 @@ class aiuNN(PreTrainedModel):
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.base_layers is not None:
|
x = self.base_layers(x)
|
||||||
x = self.base_layers(x)
|
x = self.upsample(x)
|
||||||
x = self.pixel_shuffle_conv(x)
|
x = self.final_conv(x)
|
||||||
x = self.pixel_shuffle(x)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from aiia import AIIABase, AIIAConfig
|
from aiia import AIIABase, AIIAConfig
|
||||||
# Create a configuration and build a base model.
|
# Create a configuration and build a base model.
|
||||||
|
|
Loading…
Reference in New Issue