feat/save_fix #25

Merged
Fabel merged 2 commits from feat/save_fix into main 2025-07-03 08:19:31 +00:00
4 changed files with 28 additions and 25 deletions

View File

@ -3,7 +3,7 @@ requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[project] [project]
name = "aiunn" name = "aiunn"
version = "0.3.0" version = "0.4.0"
description = "Finetuner for image upscaling using AIIA" description = "Finetuner for image upscaling using AIIA"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup( setup(
name="aiunn", name="aiunn",
version="0.3.0", version="0.4.0",
packages=find_packages(where="src"), packages=find_packages(where="src"),
package_dir={"": "src"}, package_dir={"": "src"},
install_requires=[ install_requires=[

View File

@ -4,4 +4,4 @@ from .upsampler.aiunn import aiuNN
from .upsampler.config import aiuNNConfig from .upsampler.config import aiuNNConfig
from .inference.inference import aiuNNInference from .inference.inference import aiuNNInference
__version__ = "0.3.0" __version__ = "0.4.0"

View File

@ -1,39 +1,44 @@
import os
import torch
import torch.nn as nn import torch.nn as nn
import warnings
from aiia.model.Model import AIIAConfig, AIIABase from aiia.model.Model import AIIAConfig, AIIABase
from transformers import PreTrainedModel from transformers import PreTrainedModel
from .config import aiuNNConfig from .config import aiuNNConfig
import warnings
class aiuNN(PreTrainedModel): class aiuNN(PreTrainedModel):
config_class = aiuNNConfig config_class = aiuNNConfig
def __init__(self, config: aiuNNConfig, base_model: PreTrainedModel = None): def __init__(self, config: aiuNNConfig, base_model: PreTrainedModel = None):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
# Copy base layers into aiuNN for self-containment and portability
if base_model is not None: if base_model is not None:
# Only copy submodules if base_model is provided if hasattr(base_model, 'cnn'):
self.base_layers = nn.Sequential(*[layer for layer in base_model.cnn]) self.base_layers = nn.Sequential(*[layer for layer in base_model.cnn])
elif hasattr(base_model, 'shared_layer') and hasattr(base_model, 'unique_layers'):
layers = [base_model.shared_layer, base_model.activation, base_model.max_pool]
for ul in base_model.unique_layers:
layers.extend([ul, base_model.activation, base_model.max_pool])
self.base_layers = nn.Sequential(*layers)
else:
self.base_layers = self._build_base_layers_from_config(config)
else: else:
# At inference, modules will be loaded from state_dict
self.base_layers = self._build_base_layers_from_config(config) self.base_layers = self._build_base_layers_from_config(config)
scale_factor = self.config.upsample_scale # Bilinear upsampling head
out_channels = self.config.num_channels * (scale_factor ** 2) self.upsample = nn.Upsample(
self.pixel_shuffle_conv = nn.Conv2d( scale_factor=self.config.upsample_scale,
mode='bilinear',
align_corners=False
)
self.final_conv = nn.Conv2d(
in_channels=self.config.hidden_size, in_channels=self.config.hidden_size,
out_channels=out_channels, out_channels=self.config.num_channels,
kernel_size=self.config.kernel_size, kernel_size=3,
padding=1 padding=1
) )
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
def _build_base_layers_from_config(self, config): def _build_base_layers_from_config(self, config):
"""
Reconstruct the base layers (e.g., CNN) using only the config.
This must match exactly how your base model builds its layers!
"""
layers = [] layers = []
in_channels = config.num_channels in_channels = config.num_channels
for _ in range(config.num_hidden_layers): for _ in range(config.num_hidden_layers):
@ -47,14 +52,12 @@ class aiuNN(PreTrainedModel):
return nn.Sequential(*layers) return nn.Sequential(*layers)
def forward(self, x): def forward(self, x):
if self.base_layers is not None: x = self.base_layers(x)
x = self.base_layers(x) x = self.upsample(x)
x = self.pixel_shuffle_conv(x) x = self.final_conv(x)
x = self.pixel_shuffle(x)
return x return x
if __name__ == "__main__": if __name__ == "__main__":
from aiia import AIIABase, AIIAConfig from aiia import AIIABase, AIIAConfig
# Create a configuration and build a base model. # Create a configuration and build a base model.