Moved to transformers support; fp16 load support is currently dropped.

This commit is contained in:
Falko Victor Habel 2025-04-19 22:53:00 +02:00
parent 16f8de2175
commit ced7e8a214
1 changed files with 14 additions and 9 deletions

View File

@ -9,13 +9,12 @@ import warnings
class aiuNN(PreTrainedModel):
def __init__(self, base_model: PreTrainedModel, config:aiuNNConfig):
super().__init__(base_model.config)
self.base_model = base_model
config_class = aiuNNConfig
def __init__(self, config: aiuNNConfig):
super().__init__(config)
# Pass the unified base configuration using the new parameter.
self.config = config
# Enhanced approach
scale_factor = self.config.upsample_scale
out_channels = self.base_model.config.num_channels * (scale_factor ** 2)
@ -27,13 +26,19 @@ class aiuNN(PreTrainedModel):
)
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
def load_base_model(self, base_model: PreTrainedModel):
    """Attach the backbone model that ``forward`` applies first.

    Args:
        base_model: Model stored on this instance as ``self.base_model``.
    """
    self.base_model = base_model
def forward(self, x):
    """Run the base model on ``x`` and upsample its output by pixel shuffle.

    Args:
        x: Input passed unchanged to ``self.base_model``.

    Returns:
        The result of applying ``self.base_model``, then
        ``self.pixel_shuffle_conv``, then ``self.pixel_shuffle``.

    Raises:
        ValueError: If no base model is attached (the attribute is missing
            or is ``None``).
    """
    # Read with a default so that an instance on which ``base_model`` was
    # never assigned produces the descriptive ValueError below instead of an
    # AttributeError from the attribute lookup.
    backbone = getattr(self, "base_model", None)
    if backbone is None:
        raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
    # NOTE(review): transformers' PreTrainedModel appears to define its own
    # ``base_model`` property; confirm that this attribute name resolves to
    # the module given to ``load_base_model`` and not to ``self``.
    x = backbone(x)  # Get base features
    x = self.pixel_shuffle_conv(x)  # Expand channels for shuffling
    x = self.pixel_shuffle(x)  # Rearrange channels into spatial dimensions
    return x
if __name__ == "__main__":
from aiia import AIIABase, AIIAConfig
# Create a configuration and build a base model.
@ -41,11 +46,11 @@ if __name__ == "__main__":
ai_config = aiuNNConfig()
base_model = AIIABase(config)
# Instantiate Upsampler from the base model (works correctly).
upsampler = aiuNN(base_model, config=ai_config)
upsampler = aiuNN(config=ai_config)
upsampler.load_base_model(base_model)
# Save the model (both configuration and weights).
upsampler.save_pretrained("aiunn")
# Now load using the overridden load method; this will load the complete model.
upsampler_loaded = aiuNN.load("aiunn", precision="bf16")
upsampler_loaded = aiuNN.from_pretrained("aiunn")
print("Updated configuration:", upsampler_loaded.config.__dict__)