Merge pull request 'feat/model_fix' (#19) from feat/model_fix into main
Reviewed-on: #19
This commit is contained in:
commit
a9b35dc4e0
|
@ -3,7 +3,7 @@ requires = ["setuptools>=45", "wheel"]
|
|||
build-backend = "setuptools.build_meta"
|
||||
[project]
|
||||
name = "aiunn"
|
||||
version = "0.2.2"
|
||||
version = "0.2.3"
|
||||
description = "Finetuner for image upscaling using AIIA"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
|
2
setup.py
2
setup.py
|
@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|||
|
||||
setup(
|
||||
name="aiunn",
|
||||
version="0.2.2",
|
||||
version="0.2.3",
|
||||
packages=find_packages(where="src"),
|
||||
package_dir={"": "src"},
|
||||
install_requires=[
|
||||
|
|
|
@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN
|
|||
from .upsampler.config import aiuNNConfig
|
||||
from .inference.inference import aiuNNInference
|
||||
|
||||
__version__ = "0.2.2"
|
||||
__version__ = "0.2.3"
|
|
@ -17,24 +17,24 @@ class aiuNN(PreTrainedModel):
|
|||
|
||||
# Enhanced approach
|
||||
scale_factor = self.config.upsample_scale
|
||||
out_channels = self.base_model.config.num_channels * (scale_factor ** 2)
|
||||
out_channels = self.aiia_model.config.num_channels * (scale_factor ** 2)
|
||||
self.pixel_shuffle_conv = nn.Conv2d(
|
||||
in_channels=self.base_model.config.hidden_size,
|
||||
in_channels=self.aiia_model.config.hidden_size,
|
||||
out_channels=out_channels,
|
||||
kernel_size=self.base_model.config.kernel_size,
|
||||
kernel_size=self.aiia_model.config.kernel_size,
|
||||
padding=1
|
||||
)
|
||||
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
||||
|
||||
def load_base_model(self, base_model: PreTrainedModel):
|
||||
self.base_model = base_model
|
||||
self.aiia_model = base_model
|
||||
|
||||
def forward(self, x):
|
||||
if self.base_model is None:
|
||||
if self.aiia_model is None:
|
||||
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
|
||||
|
||||
# Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict
|
||||
base_output = self.base_model(x)
|
||||
base_output = self.aiia_model(x)
|
||||
if isinstance(base_output, tuple):
|
||||
x = base_output[0]
|
||||
elif isinstance(base_output, dict):
|
||||
|
|
Loading…
Reference in New Issue