improved and simplified config
This commit is contained in:
parent
68a27f00c1
commit
de0da5de82
|
@ -2,19 +2,18 @@ import os
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import warnings
|
||||
from aiia import AIIAConfig, AIIABase
|
||||
from aiia.model.Model import AIIA
|
||||
from aiia.model.Model import AIIA, AIIAConfig, AIIABase
|
||||
from .config import aiuNNConfig
|
||||
import warnings
|
||||
|
||||
|
||||
class aiuNN(AIIA):
|
||||
def __init__(self, base_model: AIIABase):
|
||||
def __init__(self, base_model: AIIA, config:aiuNNConfig):
|
||||
super().__init__(base_model.config)
|
||||
self.base_model = base_model
|
||||
|
||||
# Pass the unified base configuration using the new parameter.
|
||||
self.config = aiuNNConfig(base_config=base_model.config)
|
||||
self.config = config
|
||||
|
||||
# Enhanced approach
|
||||
scale_factor = self.config.upsample_scale
|
||||
|
@ -29,11 +28,12 @@ class aiuNN(AIIA):
|
|||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.base_model(x)
|
||||
x = self.upsample(x)
|
||||
x = self.to_rgb(x) # Ensures output has 3 channels.
|
||||
x = self.base_model(x) # Get base features
|
||||
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
|
||||
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
||||
return x
|
||||
|
||||
|
||||
@classmethod
|
||||
def load(cls, path, precision: str = None, **kwargs):
|
||||
"""
|
||||
|
@ -57,7 +57,7 @@ class aiuNN(AIIA):
|
|||
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
|
||||
|
||||
# Import all model types
|
||||
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
|
||||
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
|
||||
|
||||
# Helper function to detect base class type from key patterns
|
||||
def detect_base_class_type(keys_prefix):
|
||||
|
@ -112,7 +112,7 @@ class aiuNN(AIIA):
|
|||
base_model = base_class(config, **kwargs)
|
||||
|
||||
# Create the aiuNN model with the detected base model
|
||||
model = cls(base_model)
|
||||
model = cls(base_model, config=base_model.config)
|
||||
|
||||
# Handle precision conversion
|
||||
dtype = None
|
||||
|
@ -143,9 +143,10 @@ if __name__ == "__main__":
|
|||
from aiia import AIIABase, AIIAConfig
|
||||
# Create a configuration and build a base model.
|
||||
config = AIIAConfig()
|
||||
ai_config = aiuNNConfig()
|
||||
base_model = AIIABase(config)
|
||||
# Instantiate Upsampler from the base model (works correctly).
|
||||
upsampler = aiuNN(base_model)
|
||||
upsampler = aiuNN(base_model, config=ai_config)
|
||||
|
||||
# Save the model (both configuration and weights).
|
||||
upsampler.save("hehe")
|
||||
|
|
Loading…
Reference in New Issue