improved and simplified config

This commit is contained in:
Falko Victor Habel 2025-03-27 17:18:40 +01:00
parent 68a27f00c1
commit de0da5de82
1 changed files with 11 additions and 10 deletions

View File

@ -2,19 +2,18 @@ import os
import torch
import torch.nn as nn
import warnings
from aiia import AIIAConfig, AIIABase
from aiia.model.Model import AIIA
from aiia.model.Model import AIIA, AIIAConfig, AIIABase
from .config import aiuNNConfig
import warnings
class aiuNN(AIIA):
def __init__(self, base_model: AIIABase):
def __init__(self, base_model: AIIA, config:aiuNNConfig):
super().__init__(base_model.config)
self.base_model = base_model
# Pass the unified base configuration using the new parameter.
self.config = aiuNNConfig(base_config=base_model.config)
self.config = config
# Enhanced approach
scale_factor = self.config.upsample_scale
@ -29,11 +28,12 @@ class aiuNN(AIIA):
def forward(self, x):
x = self.base_model(x)
x = self.upsample(x)
x = self.to_rgb(x) # Ensures output has 3 channels.
x = self.base_model(x) # Get base features
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
return x
@classmethod
def load(cls, path, precision: str = None, **kwargs):
"""
@ -57,7 +57,7 @@ class aiuNN(AIIA):
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
# Import all model types
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
# Helper function to detect base class type from key patterns
def detect_base_class_type(keys_prefix):
@ -112,7 +112,7 @@ class aiuNN(AIIA):
base_model = base_class(config, **kwargs)
# Create the aiuNN model with the detected base model
model = cls(base_model)
model = cls(base_model, config=base_model.config)
# Handle precision conversion
dtype = None
@ -143,9 +143,10 @@ if __name__ == "__main__":
from aiia import AIIABase, AIIAConfig
# Create a configuration and build a base model.
config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config)
# Instantiate Upsampler from the base model (works correctly).
upsampler = aiuNN(base_model)
upsampler = aiuNN(base_model, config=ai_config)
# Save the model (both configuration and weights).
upsampler.save("hehe")