feat/model_fix #19

Merged
Fabel merged 2 commits from feat/model_fix into main 2025-06-02 16:34:48 +00:00
4 changed files with 9 additions and 9 deletions

View File

@@ -3,7 +3,7 @@ requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[project] [project]
name = "aiunn" name = "aiunn"
version = "0.2.2" version = "0.2.3"
description = "Finetuner for image upscaling using AIIA" description = "Finetuner for image upscaling using AIIA"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup( setup(
name="aiunn", name="aiunn",
version="0.2.2", version="0.2.3",
packages=find_packages(where="src"), packages=find_packages(where="src"),
package_dir={"": "src"}, package_dir={"": "src"},
install_requires=[ install_requires=[

View File

@@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN
from .upsampler.config import aiuNNConfig from .upsampler.config import aiuNNConfig
from .inference.inference import aiuNNInference from .inference.inference import aiuNNInference
__version__ = "0.2.2" __version__ = "0.2.3"

View File

@@ -17,24 +17,24 @@ class aiuNN(PreTrainedModel):
# Enhanced approach # Enhanced approach
scale_factor = self.config.upsample_scale scale_factor = self.config.upsample_scale
out_channels = self.base_model.config.num_channels * (scale_factor ** 2) out_channels = self.aiia_model.config.num_channels * (scale_factor ** 2)
self.pixel_shuffle_conv = nn.Conv2d( self.pixel_shuffle_conv = nn.Conv2d(
in_channels=self.base_model.config.hidden_size, in_channels=self.aiia_model.config.hidden_size,
out_channels=out_channels, out_channels=out_channels,
kernel_size=self.base_model.config.kernel_size, kernel_size=self.aiia_model.config.kernel_size,
padding=1 padding=1
) )
self.pixel_shuffle = nn.PixelShuffle(scale_factor) self.pixel_shuffle = nn.PixelShuffle(scale_factor)
def load_base_model(self, base_model: PreTrainedModel): def load_base_model(self, base_model: PreTrainedModel):
self.base_model = base_model self.aiia_model = base_model
def forward(self, x): def forward(self, x):
if self.base_model is None: if self.aiia_model is None:
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.") raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
# Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict # Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict
base_output = self.base_model(x) base_output = self.aiia_model(x)
if isinstance(base_output, tuple): if isinstance(base_output, tuple):
x = base_output[0] x = base_output[0]
elif isinstance(base_output, dict): elif isinstance(base_output, dict):