moved to transformer support, currently dropped fp16 load support

This commit is contained in:
Falko Victor Habel 2025-04-19 22:53:00 +02:00
parent 16f8de2175
commit ced7e8a214
1 changed files with 14 additions and 9 deletions

View File

@ -9,10 +9,9 @@ import warnings
class aiuNN(PreTrainedModel): class aiuNN(PreTrainedModel):
def __init__(self, base_model: PreTrainedModel, config:aiuNNConfig): config_class = aiuNNConfig
super().__init__(base_model.config) def __init__(self, config: aiuNNConfig):
self.base_model = base_model super().__init__(config)
# Pass the unified base configuration using the new parameter. # Pass the unified base configuration using the new parameter.
self.config = config self.config = config
@ -27,13 +26,19 @@ class aiuNN(PreTrainedModel):
) )
self.pixel_shuffle = nn.PixelShuffle(scale_factor) self.pixel_shuffle = nn.PixelShuffle(scale_factor)
def load_base_model(self, base_model: PreTrainedModel):
self.base_model = base_model
def forward(self, x): def forward(self, x):
if self.base_model is None:
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
x = self.base_model(x) # Get base features x = self.base_model(x) # Get base features
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
return x return x
if __name__ == "__main__": if __name__ == "__main__":
from aiia import AIIABase, AIIAConfig from aiia import AIIABase, AIIAConfig
# Create a configuration and build a base model. # Create a configuration and build a base model.
@ -41,11 +46,11 @@ if __name__ == "__main__":
ai_config = aiuNNConfig() ai_config = aiuNNConfig()
base_model = AIIABase(config) base_model = AIIABase(config)
# Instantiate Upsampler from the base model (works correctly). # Instantiate Upsampler from the base model (works correctly).
upsampler = aiuNN(base_model, config=ai_config) upsampler = aiuNN(config=ai_config)
upsampler.load_base_model(base_model)
# Save the model (both configuration and weights). # Save the model (both configuration and weights).
upsampler.save_pretrained("aiunn") upsampler.save_pretrained("aiunn")
# Now load using the overridden load method; this will load the complete model. # Now load using the overridden load method; this will load the complete model.
upsampler_loaded = aiuNN.load("aiunn", precision="bf16") upsampler_loaded = aiuNN.from_pretrained("aiunn")
print("Updated configuration:", upsampler_loaded.config.__dict__) print("Updated configuration:", upsampler_loaded.config.__dict__)