Moved to Transformers support; fp16 load support is currently dropped
This commit is contained in:
parent
16f8de2175
commit
ced7e8a214
|
@ -9,10 +9,9 @@ import warnings
|
||||||
|
|
||||||
|
|
||||||
class aiuNN(PreTrainedModel):
|
class aiuNN(PreTrainedModel):
|
||||||
def __init__(self, base_model: PreTrainedModel, config:aiuNNConfig):
|
config_class = aiuNNConfig
|
||||||
super().__init__(base_model.config)
|
def __init__(self, config: aiuNNConfig):
|
||||||
self.base_model = base_model
|
super().__init__(config)
|
||||||
|
|
||||||
# Pass the unified base configuration using the new parameter.
|
# Pass the unified base configuration using the new parameter.
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
@ -27,13 +26,19 @@ class aiuNN(PreTrainedModel):
|
||||||
)
|
)
|
||||||
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
||||||
|
|
||||||
|
def load_base_model(self, base_model: PreTrainedModel):
|
||||||
|
self.base_model = base_model
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
if self.base_model is None:
|
||||||
|
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
|
||||||
x = self.base_model(x) # Get base features
|
x = self.base_model(x) # Get base features
|
||||||
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
|
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
|
||||||
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from aiia import AIIABase, AIIAConfig
|
from aiia import AIIABase, AIIAConfig
|
||||||
# Create a configuration and build a base model.
|
# Create a configuration and build a base model.
|
||||||
|
@ -41,11 +46,11 @@ if __name__ == "__main__":
|
||||||
ai_config = aiuNNConfig()
|
ai_config = aiuNNConfig()
|
||||||
base_model = AIIABase(config)
|
base_model = AIIABase(config)
|
||||||
# Instantiate Upsampler from the base model (works correctly).
|
# Instantiate Upsampler from the base model (works correctly).
|
||||||
upsampler = aiuNN(base_model, config=ai_config)
|
upsampler = aiuNN(config=ai_config)
|
||||||
|
upsampler.load_base_model(base_model)
|
||||||
# Save the model (both configuration and weights).
|
# Save the model (both configuration and weights).
|
||||||
upsampler.save_pretrained("aiunn")
|
upsampler.save_pretrained("aiunn")
|
||||||
|
|
||||||
# Now load using the overridden load method; this will load the complete model.
|
# Now load using the overridden load method; this will load the complete model.
|
||||||
upsampler_loaded = aiuNN.load("aiunn", precision="bf16")
|
upsampler_loaded = aiuNN.from_pretrained("aiunn")
|
||||||
print("Updated configuration:", upsampler_loaded.config.__dict__)
|
print("Updated configuration:", upsampler_loaded.config.__dict__)
|
||||||
|
|
Loading…
Reference in New Issue