Added fp16 and bf16 support when loading a model

Falko Victor Habel 2025-02-24 13:41:11 +01:00
parent f8e59c5896
commit b1c486afee
3 changed files with 72 additions and 34 deletions
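
In practice, the change below lets callers choose the weight precision at load time. A minimal usage sketch (AIIABase and the "test" directory follow the usage in this diff; the call assumes a model was previously saved there):

    model = AIIABase.load("test", precision="bf16")  # falls back to fp16 on GPUs without BF16 support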

View File

@@ -6,3 +6,11 @@ build-backend = "setuptools.build_meta"
 line-length = 88
 target-version = ['py37']
 include = '\.pyi?$'
+
+[project]
+name = "AIIA"
+version = "0.1.1"  # Replace with your desired version number
+description = "AIIA Deep Learning Model"
+authors = [
+    { name="Falko Habel", email="falko.habel@gmx.de" }
+]

View File

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
 setup(
     name="aiia",
-    version="0.1.0",
+    version="0.1.1",
     packages=find_packages(where="src"),
     package_dir={"": "src"},
     install_requires=[

View File

@@ -3,6 +3,7 @@ from torch import nn
 import torch
 import os
 import copy
+import warnings


 class AIIA(nn.Module):
@@ -22,15 +23,66 @@ class AIIA(nn.Module):
         self.config.save(path)

     @classmethod
-    def load(cls, path):
+    def load(cls, path, precision: str = None):
+        """
+        Load the model from the given path.
+
+        Parameters:
+        - path (str): The directory containing the saved model.
+        - precision (str, optional): The desired precision for the model weights.
+          Options are:
+              'fp16' -> load weights as torch.float16,
+              'bf16' -> load weights as torch.bfloat16.
+          If precision is None, the default torch.float32 is used.
+        """
         config = AIIAConfig.load(path)
         model = cls(config)

         # Check if CUDA is available and set the device accordingly
         device = 'cuda' if torch.cuda.is_available() else 'cpu'

-        # Load the state dictionary with the correct device mapping
-        model.load_state_dict(torch.load(f"{path}/model.pth", map_location=device))
+        dtype = None
+        if precision is not None:
+            if precision.lower() == 'fp16':
+                dtype = torch.float16
+            elif precision.lower() == 'bf16':
+                # On CUDA devices, check whether BF16 is supported; if not, fall back to FP16.
+                if device == 'cuda' and not torch.cuda.is_bf16_supported():
+                    warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
+                    dtype = torch.float16
+                else:
+                    dtype = torch.bfloat16
+            else:
+                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
+
+        # Load the state dictionary with the correct device mapping.
+        # Note: torch.load() takes no dtype argument, so the model is cast
+        # to the requested precision after the weights have been loaded.
+        model.load_state_dict(torch.load(f"{path}/model.pth", map_location=device))
+        if dtype is not None:
+            model = model.to(dtype)

         return model

+
+class AIIABase(AIIA):
+    def __init__(self, config: AIIAConfig, **kwargs):
+        super().__init__(config=config, **kwargs)
+        self.config = self.config
+
+        # Initialize layers based on configuration
+        layers = []
+        in_channels = self.config.num_channels
+
+        for _ in range(self.config.num_hidden_layers):
+            layers.extend([
+                nn.Conv2d(in_channels, self.config.hidden_size,
+                          kernel_size=self.config.kernel_size, padding=1),
+                getattr(nn, self.config.activation_function)(),
+                nn.MaxPool2d(kernel_size=1, stride=1)
+            ])
+            in_channels = self.config.hidden_size
+
+        self.cnn = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.cnn(x)
+
 class AIIABaseShared(AIIA):
     def __init__(self, config: AIIAConfig, **kwargs):
@@ -107,29 +159,6 @@ class AIIABaseShared(AIIA):
         return out

-class AIIABase(AIIA):
-    def __init__(self, config: AIIAConfig, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
-
-        # Initialize layers based on configuration
-        layers = []
-        in_channels = self.config.num_channels
-
-        for _ in range(self.config.num_hidden_layers):
-            layers.extend([
-                nn.Conv2d(in_channels, self.config.hidden_size,
-                          kernel_size=self.config.kernel_size, padding=1),
-                getattr(nn, self.config.activation_function)(),
-                nn.MaxPool2d(kernel_size=1, stride=1)
-            ])
-            in_channels = self.config.hidden_size
-
-        self.cnn = nn.Sequential(*layers)
-
-    def forward(self, x):
-        return self.cnn(x)
-
 class AIIAExpert(AIIA):
     def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
         super().__init__(config=config, **kwargs)
@@ -228,6 +257,7 @@ class AIIArecursive(AIIA):
         combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)

         return combined_output

-config = AIIAConfig()
-model = AIIAmoe(config, num_experts=5)
-model.save("test")
+if __name__ == "__main__":
+    config = AIIAConfig()
+    model = AIIAmoe(config, num_experts=5)
+    model.save("test")