Updated saving / model handling for new model version
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Successful in 9m36s
Details
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Successful in 9m36s
Details
This commit is contained in:
parent
8c4cd6612a
commit
5d0aadebab
|
@ -34,13 +34,12 @@ config = AIIAConfig()
|
|||
ai_config = aiuNNConfig()
|
||||
|
||||
base_model = AIIABase(config)
|
||||
upscaler = aiuNN(config=ai_config)
|
||||
|
||||
|
||||
# Load your base model and upscaler
|
||||
pretrained_model_path = "path/to/aiia/model"
|
||||
base_model = AIIABase.from_pretrained(pretrained_model_path)
|
||||
upscaler.load_base_model(base_model)
|
||||
|
||||
upscaler = aiuNN(config=ai_config, base_model=base_model)
|
||||
# Create trainer with your dataset class
|
||||
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
|
||||
|
||||
|
|
|
@ -82,7 +82,7 @@ if __name__ =="__main__":
|
|||
# Load your base model and upscaler
|
||||
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
|
||||
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
|
||||
upscaler = aiuNN(base_model)
|
||||
upscaler = aiuNN(base_model,base_model=base_model)
|
||||
|
||||
# Create trainer with your dataset class
|
||||
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
|
||||
|
|
|
@ -3,7 +3,7 @@ requires = ["setuptools>=45", "wheel"]
|
|||
build-backend = "setuptools.build_meta"
|
||||
[project]
|
||||
name = "aiunn"
|
||||
version = "0.2.4"
|
||||
version = "0.3.0"
|
||||
description = "Finetuner for image upscaling using AIIA"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
|
2
setup.py
2
setup.py
|
@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|||
|
||||
setup(
|
||||
name="aiunn",
|
||||
version="0.2.4",
|
||||
version="0.3.0",
|
||||
packages=find_packages(where="src"),
|
||||
package_dir={"": "src"},
|
||||
install_requires=[
|
||||
|
|
|
@ -4,4 +4,4 @@ from .upsampler.aiunn import aiuNN
|
|||
from .upsampler.config import aiuNNConfig
|
||||
from .inference.inference import aiuNNInference
|
||||
|
||||
__version__ = "0.2.4"
|
||||
__version__ = "0.3.0"
|
|
@ -7,15 +7,18 @@ from transformers import PreTrainedModel
|
|||
from .config import aiuNNConfig
|
||||
import warnings
|
||||
|
||||
|
||||
class aiuNN(PreTrainedModel):
|
||||
config_class = aiuNNConfig
|
||||
def __init__(self, config: aiuNNConfig):
|
||||
def __init__(self, config: aiuNNConfig, base_model: PreTrainedModel = None):
|
||||
super().__init__(config)
|
||||
# Pass the unified base configuration using the new parameter.
|
||||
self.config = config
|
||||
if base_model is not None:
|
||||
# Only copy submodules if base_model is provided
|
||||
self.base_layers = nn.Sequential(*[layer for layer in base_model.cnn])
|
||||
else:
|
||||
# At inference, modules will be loaded from state_dict
|
||||
self.base_layers = self._build_base_layers_from_config(config)
|
||||
|
||||
# Enhanced approach
|
||||
scale_factor = self.config.upsample_scale
|
||||
out_channels = self.config.num_channels * (scale_factor ** 2)
|
||||
self.pixel_shuffle_conv = nn.Conv2d(
|
||||
|
@ -26,26 +29,28 @@ class aiuNN(PreTrainedModel):
|
|||
)
|
||||
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
||||
|
||||
def load_base_model(self, base_model: PreTrainedModel):
|
||||
self.aiia_model = base_model
|
||||
|
||||
def _build_base_layers_from_config(self, config):
|
||||
"""
|
||||
Reconstruct the base layers (e.g., CNN) using only the config.
|
||||
This must match exactly how your base model builds its layers!
|
||||
"""
|
||||
layers = []
|
||||
in_channels = config.num_channels
|
||||
for _ in range(config.num_hidden_layers):
|
||||
layers.extend([
|
||||
nn.Conv2d(in_channels, config.hidden_size,
|
||||
kernel_size=config.kernel_size, padding=1),
|
||||
getattr(nn, config.activation_function)(),
|
||||
nn.MaxPool2d(kernel_size=1, stride=1)
|
||||
])
|
||||
in_channels = config.hidden_size
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.aiia_model is None:
|
||||
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
|
||||
|
||||
# Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict
|
||||
base_output = self.aiia_model(x)
|
||||
if isinstance(base_output, tuple):
|
||||
x = base_output[0]
|
||||
elif isinstance(base_output, dict):
|
||||
x = base_output.get('last_hidden_state', base_output.get('hidden_states'))
|
||||
if x is None:
|
||||
raise ValueError("Expected 'last_hidden_state' or 'hidden_states' in model output")
|
||||
else:
|
||||
x = base_output
|
||||
|
||||
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
|
||||
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
||||
if self.base_layers is not None:
|
||||
x = self.base_layers(x)
|
||||
x = self.pixel_shuffle_conv(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -57,8 +62,7 @@ if __name__ == "__main__":
|
|||
ai_config = aiuNNConfig()
|
||||
base_model = AIIABase(config)
|
||||
# Instantiate Upsampler from the base model (works correctly).
|
||||
upsampler = aiuNN(config=ai_config)
|
||||
upsampler.load_base_model(base_model)
|
||||
upsampler = aiuNN(config=ai_config, base_model=base_model)
|
||||
# Save the model (both configuration and weights).
|
||||
upsampler.save_pretrained("aiunn")
|
||||
|
||||
|
|
|
@ -21,8 +21,7 @@ def real_model(tmp_path):
|
|||
base_model = AIIABase(config)
|
||||
|
||||
# Make sure aiuNN is properly configured with all required attributes
|
||||
upsampler = aiuNN(config=ai_config)
|
||||
upsampler.load_base_model(base_model)
|
||||
upsampler = aiuNN(config=ai_config, base_model=base_model)
|
||||
|
||||
# Save the model and config to temporary directory
|
||||
save_path = str(model_dir / "save")
|
||||
|
|
|
@ -10,8 +10,7 @@ def test_save_and_load_model():
|
|||
config = AIIAConfig()
|
||||
ai_config = aiuNNConfig()
|
||||
base_model = AIIABase(config)
|
||||
upsampler = aiuNN(config=ai_config)
|
||||
upsampler.load_base_model(base_model)
|
||||
upsampler = aiuNN(config=ai_config, base_model=base_model)
|
||||
# Save the model
|
||||
save_path = os.path.join(tmpdirname, "model")
|
||||
upsampler.save_pretrained(save_path)
|
||||
|
|
Loading…
Reference in New Issue