develop #4
@@ -1,4 +1,6 @@
 from aiia import AIIAConfig
+import os
+import json


 class UpsamplerConfig(AIIAConfig):
@@ -33,4 +35,3 @@ class UpsamplerConfig(AIIAConfig):
         # Add the upsample layer only if not already present.
         if not any(layer.get('name') == 'Upsample' for layer in self.layers):
             self.layers.append(upsample_layer)
-
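The guard in the second hunk makes the layer registration idempotent: even if the configuration is built or re-loaded more than once, the Upsample entry is appended only the first time. A minimal standalone sketch of that pattern; the layer dicts below are hypothetical, only the 'name' key is taken from the diff:

# Hypothetical layer schema; the real field names are defined by AIIAConfig.
layers = [{'name': 'Conv2d', 'out_channels': 64}]
upsample_layer = {'name': 'Upsample', 'scale_factor': 2}

for _ in range(3):  # simulate the config being initialised repeatedly
    if not any(layer.get('name') == 'Upsample' for layer in layers):
        layers.append(upsample_layer)

print(sum(layer['name'] == 'Upsample' for layer in layers))  # 1, appended exactly once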

@@ -1,9 +1,12 @@
+import os
 import torch
 import torch.nn as nn
+import warnings
 from aiia import AIIA, AIIAConfig, AIIABase
 from config import UpsamplerConfig
-import warnings


-# Upsampler model that uses the configuration from the base model.
 class Upsampler(AIIA):
     def __init__(self, base_model: AIIABase):
         super().__init__(base_model.config)
@@ -14,43 +17,62 @@ class Upsampler(AIIA):
             mode=self.config.upsample_mode,
             align_corners=self.config.upsample_align_corners
         )
-        # Conversion layer: change from 512 channels to 3 channels.
-        self.to_rgb = nn.Conv2d(in_channels=self.base_model.config.hidden_size, out_channels=3, kernel_size=1)
+        # Conversion layer: change from hidden size channels to 3 channels.
+        self.to_rgb = nn.Conv2d(
+            in_channels=self.base_model.config.hidden_size,
+            out_channels=3,
+            kernel_size=1
+        )

     def forward(self, x):
         x = self.base_model(x)
         x = self.upsample(x)
         x = self.to_rgb(x)  # Ensures output has 3 channels.
         return x

     @classmethod
-    def load(cls, path: str):
-        """
-        Load the model:
-        - First, load the base model (including its configuration and state_dict).
-        - Then, wrap it with the Upsampler class.
-        - Finally, load the combined state dictionary.
-        """
-        base_model = AIIABase.load(path)
-        instance = cls(base_model)
+    def load(cls, path, precision: str = None):
+        # Load the configuration from disk.
+        config = AIIAConfig.load(path)
+        # Reconstruct the base model from the loaded configuration.
+        base_model = AIIABase(config)
+        # Instantiate the Upsampler using the proper base model.
+        upsampler = cls(base_model)

+        # Load state dict and handle precision conversion if needed.
         device = 'cuda' if torch.cuda.is_available() else 'cpu'
         state_dict = torch.load(f"{path}/model.pth", map_location=device)
-        instance.load_state_dict(state_dict)
-        return instance
+        if precision is not None:
+            if precision.lower() == 'fp16':
+                dtype = torch.float16
+            elif precision.lower() == 'bf16':
+                if device == 'cuda' and not torch.cuda.is_bf16_supported():
+                    warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
+                    dtype = torch.float16
+                else:
+                    dtype = torch.bfloat16
+            else:
+                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
+
+            for key, param in state_dict.items():
+                if torch.is_tensor(param):
+                    state_dict[key] = param.to(dtype)
+
+        upsampler.load_state_dict(state_dict)
+        return upsampler


-if __name__ == "main":
+if __name__ == "__main__":
     from aiia import AIIABase, AIIAConfig
     # Create a configuration and build a base model.
     config = AIIAConfig()
-    base_model = AIIABase("test2")
+    base_model = AIIABase(config)
     # Instantiate Upsampler from the base model (works correctly).
     upsampler = Upsampler(base_model)

     # Save the model (both configuration and weights).
-    upsampler.save("test2")
+    upsampler.save("hehe")

     # Now load using the overridden load method; this will load the complete model.
-    upsampler_loaded = Upsampler.load("test2")
+    upsampler_loaded = Upsampler.load("hehe", precision="bf16")
     print("Updated configuration:", upsampler_loaded.config.__dict__)
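Taken together, __init__ and forward upsample the base model's feature map and project it down to three channels with a 1x1 convolution. A minimal shape check of that pattern in isolation; the hidden size, scale factor, and interpolation settings below are placeholder values, the real ones come from the AIIA/Upsampler configuration:

import torch
import torch.nn as nn

hidden_size = 512  # assumed channel count of the base model's feature map
upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
to_rgb = nn.Conv2d(in_channels=hidden_size, out_channels=3, kernel_size=1)

features = torch.randn(1, hidden_size, 64, 64)  # stand-in for the base model output
out = to_rgb(upsample(features))
print(out.shape)  # torch.Size([1, 3, 128, 128]): doubled spatial size, 3 channels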
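The precision argument added to load() converts every tensor in the loaded state dict to the requested dtype before it is applied to the model. The same conversion logic, pulled out into a standalone sketch; the helper name and the toy state dict are hypothetical stand-ins for the real model.pth:

import warnings
import torch

def convert_state_dict(state_dict, precision, device='cpu'):
    # Same decision logic as the new load(): pick a dtype, and fall back to
    # FP16 when BF16 is requested but the GPU cannot execute it.
    if precision.lower() == 'fp16':
        dtype = torch.float16
    elif precision.lower() == 'bf16':
        if device == 'cuda' and not torch.cuda.is_bf16_supported():
            warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
            dtype = torch.float16
        else:
            dtype = torch.bfloat16
    else:
        raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
    return {k: v.to(dtype) if torch.is_tensor(v) else v for k, v in state_dict.items()}

toy = {'to_rgb.weight': torch.randn(3, 512, 1, 1), 'to_rgb.bias': torch.zeros(3)}
print(convert_state_dict(toy, 'bf16')['to_rgb.weight'].dtype)  # torch.bfloat16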