Updated saving / model handling for new model version
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Successful in 9m36s
Details
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Successful in 9m36s
Details
This commit is contained in:
parent
8c4cd6612a
commit
5d0aadebab
|
@ -34,13 +34,12 @@ config = AIIAConfig()
|
||||||
ai_config = aiuNNConfig()
|
ai_config = aiuNNConfig()
|
||||||
|
|
||||||
base_model = AIIABase(config)
|
base_model = AIIABase(config)
|
||||||
upscaler = aiuNN(config=ai_config)
|
|
||||||
|
|
||||||
# Load your base model and upscaler
|
# Load your base model and upscaler
|
||||||
pretrained_model_path = "path/to/aiia/model"
|
pretrained_model_path = "path/to/aiia/model"
|
||||||
base_model = AIIABase.from_pretrained(pretrained_model_path)
|
base_model = AIIABase.from_pretrained(pretrained_model_path)
|
||||||
upscaler.load_base_model(base_model)
|
upscaler = aiuNN(config=ai_config, base_model=base_model)
|
||||||
|
|
||||||
# Create trainer with your dataset class
|
# Create trainer with your dataset class
|
||||||
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
|
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
|
||||||
|
|
||||||
|
|
|
@ -82,7 +82,7 @@ if __name__ =="__main__":
|
||||||
# Load your base model and upscaler
|
# Load your base model and upscaler
|
||||||
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
|
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
|
||||||
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
|
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
|
||||||
upscaler = aiuNN(base_model)
|
upscaler = aiuNN(base_model,base_model=base_model)
|
||||||
|
|
||||||
# Create trainer with your dataset class
|
# Create trainer with your dataset class
|
||||||
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
|
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
|
||||||
|
|
|
@ -3,7 +3,7 @@ requires = ["setuptools>=45", "wheel"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
[project]
|
[project]
|
||||||
name = "aiunn"
|
name = "aiunn"
|
||||||
version = "0.2.4"
|
version = "0.3.0"
|
||||||
description = "Finetuner for image upscaling using AIIA"
|
description = "Finetuner for image upscaling using AIIA"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="aiunn",
|
name="aiunn",
|
||||||
version="0.2.4",
|
version="0.3.0",
|
||||||
packages=find_packages(where="src"),
|
packages=find_packages(where="src"),
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
|
|
|
@ -4,4 +4,4 @@ from .upsampler.aiunn import aiuNN
|
||||||
from .upsampler.config import aiuNNConfig
|
from .upsampler.config import aiuNNConfig
|
||||||
from .inference.inference import aiuNNInference
|
from .inference.inference import aiuNNInference
|
||||||
|
|
||||||
__version__ = "0.2.4"
|
__version__ = "0.3.0"
|
|
@ -7,15 +7,18 @@ from transformers import PreTrainedModel
|
||||||
from .config import aiuNNConfig
|
from .config import aiuNNConfig
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
class aiuNN(PreTrainedModel):
|
class aiuNN(PreTrainedModel):
|
||||||
config_class = aiuNNConfig
|
config_class = aiuNNConfig
|
||||||
def __init__(self, config: aiuNNConfig):
|
def __init__(self, config: aiuNNConfig, base_model: PreTrainedModel = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
# Pass the unified base configuration using the new parameter.
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
if base_model is not None:
|
||||||
|
# Only copy submodules if base_model is provided
|
||||||
|
self.base_layers = nn.Sequential(*[layer for layer in base_model.cnn])
|
||||||
|
else:
|
||||||
|
# At inference, modules will be loaded from state_dict
|
||||||
|
self.base_layers = self._build_base_layers_from_config(config)
|
||||||
|
|
||||||
# Enhanced approach
|
|
||||||
scale_factor = self.config.upsample_scale
|
scale_factor = self.config.upsample_scale
|
||||||
out_channels = self.config.num_channels * (scale_factor ** 2)
|
out_channels = self.config.num_channels * (scale_factor ** 2)
|
||||||
self.pixel_shuffle_conv = nn.Conv2d(
|
self.pixel_shuffle_conv = nn.Conv2d(
|
||||||
|
@ -26,26 +29,28 @@ class aiuNN(PreTrainedModel):
|
||||||
)
|
)
|
||||||
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
||||||
|
|
||||||
def load_base_model(self, base_model: PreTrainedModel):
|
def _build_base_layers_from_config(self, config):
|
||||||
self.aiia_model = base_model
|
"""
|
||||||
|
Reconstruct the base layers (e.g., CNN) using only the config.
|
||||||
|
This must match exactly how your base model builds its layers!
|
||||||
|
"""
|
||||||
|
layers = []
|
||||||
|
in_channels = config.num_channels
|
||||||
|
for _ in range(config.num_hidden_layers):
|
||||||
|
layers.extend([
|
||||||
|
nn.Conv2d(in_channels, config.hidden_size,
|
||||||
|
kernel_size=config.kernel_size, padding=1),
|
||||||
|
getattr(nn, config.activation_function)(),
|
||||||
|
nn.MaxPool2d(kernel_size=1, stride=1)
|
||||||
|
])
|
||||||
|
in_channels = config.hidden_size
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.aiia_model is None:
|
if self.base_layers is not None:
|
||||||
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
|
x = self.base_layers(x)
|
||||||
|
x = self.pixel_shuffle_conv(x)
|
||||||
# Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict
|
x = self.pixel_shuffle(x)
|
||||||
base_output = self.aiia_model(x)
|
|
||||||
if isinstance(base_output, tuple):
|
|
||||||
x = base_output[0]
|
|
||||||
elif isinstance(base_output, dict):
|
|
||||||
x = base_output.get('last_hidden_state', base_output.get('hidden_states'))
|
|
||||||
if x is None:
|
|
||||||
raise ValueError("Expected 'last_hidden_state' or 'hidden_states' in model output")
|
|
||||||
else:
|
|
||||||
x = base_output
|
|
||||||
|
|
||||||
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
|
|
||||||
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,8 +62,7 @@ if __name__ == "__main__":
|
||||||
ai_config = aiuNNConfig()
|
ai_config = aiuNNConfig()
|
||||||
base_model = AIIABase(config)
|
base_model = AIIABase(config)
|
||||||
# Instantiate Upsampler from the base model (works correctly).
|
# Instantiate Upsampler from the base model (works correctly).
|
||||||
upsampler = aiuNN(config=ai_config)
|
upsampler = aiuNN(config=ai_config, base_model=base_model)
|
||||||
upsampler.load_base_model(base_model)
|
|
||||||
# Save the model (both configuration and weights).
|
# Save the model (both configuration and weights).
|
||||||
upsampler.save_pretrained("aiunn")
|
upsampler.save_pretrained("aiunn")
|
||||||
|
|
||||||
|
|
|
@ -21,8 +21,7 @@ def real_model(tmp_path):
|
||||||
base_model = AIIABase(config)
|
base_model = AIIABase(config)
|
||||||
|
|
||||||
# Make sure aiuNN is properly configured with all required attributes
|
# Make sure aiuNN is properly configured with all required attributes
|
||||||
upsampler = aiuNN(config=ai_config)
|
upsampler = aiuNN(config=ai_config, base_model=base_model)
|
||||||
upsampler.load_base_model(base_model)
|
|
||||||
|
|
||||||
# Save the model and config to temporary directory
|
# Save the model and config to temporary directory
|
||||||
save_path = str(model_dir / "save")
|
save_path = str(model_dir / "save")
|
||||||
|
|
|
@ -10,8 +10,7 @@ def test_save_and_load_model():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
ai_config = aiuNNConfig()
|
ai_config = aiuNNConfig()
|
||||||
base_model = AIIABase(config)
|
base_model = AIIABase(config)
|
||||||
upsampler = aiuNN(config=ai_config)
|
upsampler = aiuNN(config=ai_config, base_model=base_model)
|
||||||
upsampler.load_base_model(base_model)
|
|
||||||
# Save the model
|
# Save the model
|
||||||
save_path = os.path.join(tmpdirname, "model")
|
save_path = os.path.join(tmpdirname, "model")
|
||||||
upsampler.save_pretrained(save_path)
|
upsampler.save_pretrained(save_path)
|
||||||
|
|
Loading…
Reference in New Issue