develop #11

Merged
Fabel merged 22 commits from develop into main 2025-04-04 20:12:21 +00:00
1 changed files with 11 additions and 10 deletions
Showing only changes of commit de0da5de82 - Show all commits

View File

@ -2,19 +2,18 @@ import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import warnings import warnings
from aiia import AIIAConfig, AIIABase from aiia.model.Model import AIIA, AIIAConfig, AIIABase
from aiia.model.Model import AIIA
from .config import aiuNNConfig from .config import aiuNNConfig
import warnings import warnings
class aiuNN(AIIA): class aiuNN(AIIA):
def __init__(self, base_model: AIIABase): def __init__(self, base_model: AIIA, config:aiuNNConfig):
super().__init__(base_model.config) super().__init__(base_model.config)
self.base_model = base_model self.base_model = base_model
# Pass the unified base configuration using the new parameter. # Pass the unified base configuration using the new parameter.
self.config = aiuNNConfig(base_config=base_model.config) self.config = config
# Enhanced approach # Enhanced approach
scale_factor = self.config.upsample_scale scale_factor = self.config.upsample_scale
@ -29,11 +28,12 @@ class aiuNN(AIIA):
def forward(self, x): def forward(self, x):
x = self.base_model(x) x = self.base_model(x) # Get base features
x = self.upsample(x) x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
x = self.to_rgb(x) # Ensures output has 3 channels. x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
return x return x
@classmethod @classmethod
def load(cls, path, precision: str = None, **kwargs): def load(cls, path, precision: str = None, **kwargs):
""" """
@ -57,7 +57,7 @@ class aiuNN(AIIA):
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device) state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
# Import all model types # Import all model types
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
# Helper function to detect base class type from key patterns # Helper function to detect base class type from key patterns
def detect_base_class_type(keys_prefix): def detect_base_class_type(keys_prefix):
@ -112,7 +112,7 @@ class aiuNN(AIIA):
base_model = base_class(config, **kwargs) base_model = base_class(config, **kwargs)
# Create the aiuNN model with the detected base model # Create the aiuNN model with the detected base model
model = cls(base_model) model = cls(base_model, config=base_model.config)
# Handle precision conversion # Handle precision conversion
dtype = None dtype = None
@ -143,9 +143,10 @@ if __name__ == "__main__":
from aiia import AIIABase, AIIAConfig from aiia import AIIABase, AIIAConfig
# Create a configuration and build a base model. # Create a configuration and build a base model.
config = AIIAConfig() config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config) base_model = AIIABase(config)
# Instantiate Upsampler from the base model (works correctly). # Instantiate Upsampler from the base model (works correctly).
upsampler = aiuNN(base_model) upsampler = aiuNN(base_model, config=ai_config)
# Save the model (both configuration and weights). # Save the model (both configuration and weights).
upsampler.save("hehe") upsampler.save("hehe")