diff --git a/pyproject.toml b/pyproject.toml
index 68125b7..873f10f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ requires = ["setuptools>=45", "wheel"]
 build-backend = "setuptools.build_meta"
 [project]
 name = "aiunn"
-version = "0.2.2"
+version = "0.2.3"
 description = "Finetuner for image upscaling using AIIA"
 readme = "README.md"
 requires-python = ">=3.10"
diff --git a/setup.py b/setup.py
index 57c0fa6..6e22059 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
 
 setup(
     name="aiunn",
-    version="0.2.2",
+    version="0.2.3",
     packages=find_packages(where="src"),
     package_dir={"": "src"},
     install_requires=[
diff --git a/src/aiunn/__init__.py b/src/aiunn/__init__.py
index dd7a3cd..05c37ab 100644
--- a/src/aiunn/__init__.py
+++ b/src/aiunn/__init__.py
@@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN
 from .upsampler.config import aiuNNConfig
 from .inference.inference import aiuNNInference
 
-__version__ = "0.2.2"
\ No newline at end of file
+__version__ = "0.2.3"
\ No newline at end of file
diff --git a/src/aiunn/upsampler/aiunn.py b/src/aiunn/upsampler/aiunn.py
index a5ba3b4..978a75d 100644
--- a/src/aiunn/upsampler/aiunn.py
+++ b/src/aiunn/upsampler/aiunn.py
@@ -17,24 +17,24 @@ class aiuNN(PreTrainedModel):
 
         # Enhanced approach
         scale_factor = self.config.upsample_scale
-        out_channels = self.base_model.config.num_channels * (scale_factor ** 2)
+        out_channels = self.aiia_model.config.num_channels * (scale_factor ** 2)
         self.pixel_shuffle_conv = nn.Conv2d(
-            in_channels=self.base_model.config.hidden_size,
+            in_channels=self.aiia_model.config.hidden_size,
             out_channels=out_channels,
-            kernel_size=self.base_model.config.kernel_size,
+            kernel_size=self.aiia_model.config.kernel_size,
             padding=1
         )
         self.pixel_shuffle = nn.PixelShuffle(scale_factor)
 
     def load_base_model(self, base_model: PreTrainedModel):
-        self.base_model = base_model
+        self.aiia_model = base_model
 
     def forward(self, x):
-        if self.base_model is None:
+        if self.aiia_model is None:
             raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
 
         # Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict
-        base_output = self.base_model(x)
+        base_output = self.aiia_model(x)
         if isinstance(base_output, tuple):
             x = base_output[0]
         elif isinstance(base_output, dict):
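
Reviewer note on the `base_model` -> `aiia_model` rename (a sketch, not part of the diff): `transformers.PreTrainedModel` already exposes `base_model` as a read-only property that resolves through `base_model_prefix`, so keeping the AIIA backbone under `self.base_model` means attribute reads go through that inherited property rather than the module that was assigned. The sketch below illustrates the collision under that assumption; `DummyConfig` and `Wrapper` are made-up names for illustration and are not part of aiunn.

```python
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class DummyConfig(PretrainedConfig):
    model_type = "dummy"


class Wrapper(PreTrainedModel):
    config_class = DummyConfig

    def __init__(self, config):
        super().__init__(config)
        # nn.Module.__setattr__ stores the Linear in self._modules["base_model"] ...
        self.base_model = nn.Linear(4, 4)


w = Wrapper(DummyConfig())
# ... but reads of `w.base_model` resolve through PreTrainedModel's class-level
# property, which returns getattr(self, self.base_model_prefix, self); with the
# default empty prefix that is the wrapper itself, not the Linear assigned above.
print(type(w.base_model))  # <class '__main__.Wrapper'>, not nn.Linear
```

Keeping the backbone under `self.aiia_model` avoids that shadowing, which is presumably why the attribute (and every read of it in `__init__` and `forward`) is renamed here; `load_base_model` keeps its public name, so callers are unaffected.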