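Fixed a bug in model loading: load the state dict without a dtype argument and cast its tensors to the requested precision afterwards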
This commit is contained in:
parent
b1c486afee
commit
50e91b10e8
|
@ -8,9 +8,23 @@ target-version = ['py37']
|
|||
include = '\.pyi?$'
|
||||
|
||||
[project]
|
||||
name = "AIIA"
|
||||
version = "0.1.1" # Replace with your desired version number
|
||||
description = "AIIA Deep Learning Model"
|
||||
name = "aiia"
|
||||
version = "0.1.1"
|
||||
description = "AIIA Deep Learning Model Implementation"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name="Falko Habel", email="falko.habel@gmx.de" }
|
||||
]
|
||||
dependencies = [
|
||||
"torch>=2.5.0",
|
||||
"numpy",
|
||||
"tqdm",
|
||||
"pytest",
|
||||
"pillow"
|
||||
]
|
||||
requires-python = ">=3.7"
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent"
|
||||
]
|
25
run.py
25
run.py
|
@ -1,27 +1,6 @@
|
|||
data_path1 = "/root/training_data/vision-dataset/images_pretrain.parquet"
|
||||
data_path2 = "/root/training_data/vision-dataset/vector_img_pretrain.parquet"
|
||||
|
||||
from aiia.model import AIIABase
|
||||
from aiia.model.config import AIIAConfig
|
||||
from aiia.pretrain import Pretrainer
|
||||
|
||||
# Create your model
|
||||
config = AIIAConfig(model_name="AIIA-Base-512x20k")
|
||||
model = AIIABase(config)
|
||||
|
||||
# Initialize pretrainer with the model
|
||||
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)
|
||||
|
||||
# List of dataset paths
|
||||
dataset_paths = [
|
||||
data_path1,
|
||||
data_path2
|
||||
]
|
||||
from aiia import AIIABase
|
||||
|
||||
# Start training with multiple datasets
|
||||
pretrainer.train(
|
||||
dataset_paths=dataset_paths,
|
||||
num_epochs=10,
|
||||
batch_size=2,
|
||||
sample_size=10000
|
||||
)
|
||||
model = AIIABase.load(path="AIIA-base-512", precision="bf16")
|
21
setup.py
21
setup.py
|
@ -1,25 +1,6 @@
|
|||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name="aiia",
|
||||
version="0.1.1",
|
||||
packages=find_packages(where="src"),
|
||||
package_dir={"": "src"},
|
||||
install_requires=[
|
||||
"torch>=1.8.0",
|
||||
"numpy>=1.19.0",
|
||||
"tqdm>=4.62.0",
|
||||
],
|
||||
author="Falko Habel",
|
||||
author_email="falko.habel@gmx.de",
|
||||
description="AIIA deep learning model implementation",
|
||||
long_description=open("README.md").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://gitea.fabelous.app/Maschine-Learning/AIIA.git",
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: Creative Commons Attribution-NonCommercial 4.0 International",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
python_requires=">=3.10",
|
||||
)
|
||||
)
|
|
@ -24,17 +24,6 @@ class AIIA(nn.Module):
|
|||
|
||||
@classmethod
|
||||
def load(cls, path, precision: str = None):
|
||||
"""
|
||||
Load the model from the given path.
|
||||
|
||||
Parameters:
|
||||
- path (str): The directory containing the saved model.
|
||||
- precision (str, optional): The desired precision for model weights.
|
||||
Options are:
|
||||
'fp16' -> load weights with torch.float16,
|
||||
'bf16' -> load weights with torch.bfloat16.
|
||||
If precision is None, default torch.float32 is used.
|
||||
"""
|
||||
config = AIIAConfig.load(path)
|
||||
model = cls(config)
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
@ -44,7 +33,6 @@ class AIIA(nn.Module):
|
|||
if precision.lower() == 'fp16':
|
||||
dtype = torch.float16
|
||||
elif precision.lower() == 'bf16':
|
||||
# For CUDA devices, check whether BF16 is supported. If not, fallback to FP16.
|
||||
if device == 'cuda' and not torch.cuda.is_bf16_supported():
|
||||
warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
|
||||
dtype = torch.float16
|
||||
|
@ -53,14 +41,19 @@ class AIIA(nn.Module):
|
|||
else:
|
||||
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
|
||||
|
||||
# Load the state dictionary normally (without dtype argument)
|
||||
model_dict = torch.load(f"{path}/model.pth", map_location=device)
|
||||
|
||||
# If a precision conversion is requested, cast each tensor in the state dict to the target dtype.
|
||||
if dtype is not None:
|
||||
model_dict = torch.load(f"{path}/model.pth", map_location=device, dtype=dtype)
|
||||
else:
|
||||
model_dict = torch.load(f"{path}/model.pth", map_location=device)
|
||||
|
||||
for key, param in model_dict.items():
|
||||
if torch.is_tensor(param):
|
||||
model_dict[key] = param.to(dtype)
|
||||
|
||||
model.load_state_dict(model_dict)
|
||||
return model
|
||||
|
||||
|
||||
class AIIABase(AIIA):
|
||||
def __init__(self, config: AIIAConfig, **kwargs):
|
||||
super().__init__(config=config, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue