develop #18
|
@ -1,4 +1,4 @@
|
||||||
include LICENSE
|
include LICENSE
|
||||||
include README.md
|
include README.md
|
||||||
include requirements.txt
|
include requirements.txt
|
||||||
recursive-include src/aiia *
|
recursive-include src/aiunn *
|
|
@ -3,7 +3,7 @@ requires = ["setuptools>=45", "wheel"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
[project]
|
[project]
|
||||||
name = "aiunn"
|
name = "aiunn"
|
||||||
version = "0.1.1"
|
version = "0.2.2"
|
||||||
description = "Finetuner for image upscaling using AIIA"
|
description = "Finetuner for image upscaling using AIIA"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
torch
|
torch
|
||||||
aiia
|
|
||||||
pillow
|
pillow
|
||||||
|
pandas
|
||||||
torchvision
|
torchvision
|
||||||
scikit-learn
|
scikit-learn
|
||||||
git+https://gitea.fabelous.app/Machine-Learning/AIIA.git
|
git+https://gitea.fabelous.app/Machine-Learning/AIIA.git
|
2
setup.py
2
setup.py
|
@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="aiunn",
|
name="aiunn",
|
||||||
version="0.2.1",
|
version="0.2.2",
|
||||||
packages=find_packages(where="src"),
|
packages=find_packages(where="src"),
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
|
|
|
@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN
|
||||||
from .upsampler.config import aiuNNConfig
|
from .upsampler.config import aiuNNConfig
|
||||||
from .inference.inference import aiuNNInference
|
from .inference.inference import aiuNNInference
|
||||||
|
|
||||||
__version__ = "0.2.1"
|
__version__ = "0.2.2"
|
|
@ -32,7 +32,18 @@ class aiuNN(PreTrainedModel):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.base_model is None:
|
if self.base_model is None:
|
||||||
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
|
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
|
||||||
x = self.base_model(x) # Get base features
|
|
||||||
|
# Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict
|
||||||
|
base_output = self.base_model(x)
|
||||||
|
if isinstance(base_output, tuple):
|
||||||
|
x = base_output[0]
|
||||||
|
elif isinstance(base_output, dict):
|
||||||
|
x = base_output.get('last_hidden_state', base_output.get('hidden_states'))
|
||||||
|
if x is None:
|
||||||
|
raise ValueError("Expected 'last_hidden_state' or 'hidden_states' in model output")
|
||||||
|
else:
|
||||||
|
x = base_output
|
||||||
|
|
||||||
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
|
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
|
||||||
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
||||||
return x
|
return x
|
||||||
|
|
Loading…
Reference in New Issue