Compare commits
No commits in common. "633e6c845955346f8da9570183552219e474b15a" and "5f16b6ca6c89e0229e04b200c36d629b1bb0d01a" have entirely different histories.
633e6c8459 ... 5f16b6ca6c
@@ -23,36 +23,12 @@ class AIIA(nn.Module):
         self.config.save(path)
 
     @classmethod
-    def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
+    def load(cls, path, precision: str = None, **kwargs):
         config = AIIAConfig.load(path)
+        model = cls(config, **kwargs) # Pass kwargs here!
         device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-        # Load the state dict to analyze structure
-        model_dict = torch.load(f"{path}/model.pth", map_location=device)
-
-        # Special handling for AIIAmoe - detect number of experts from state_dict
-        if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs:
-            # Find maximum expert index
-            max_expert_idx = -1
-            for key in model_dict.keys():
-                if key.startswith("experts."):
-                    parts = key.split(".")
-                    if len(parts) > 1:
-                        try:
-                            expert_idx = int(parts[1])
-                            max_expert_idx = max(max_expert_idx, expert_idx)
-                        except ValueError:
-                            pass
-
-            if max_expert_idx >= 0:
-                # experts.X keys found, use max_expert_idx + 1 as num_experts
-                kwargs["num_experts"] = max_expert_idx + 1
-
-        # Create model with detected structural parameters
-        model = cls(config, **kwargs)
-
-        # Handle precision conversion
         dtype = None
         if precision is not None:
             if precision.lower() == 'fp16':
                 dtype = torch.float16
@@ -65,13 +41,13 @@ class AIIA(nn.Module):
             else:
                 raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
 
+        model_dict = torch.load(f"{path}/model.pth", map_location=device)
         if dtype is not None:
             for key, param in model_dict.items():
                 if torch.is_tensor(param):
                     model_dict[key] = param.to(dtype)
 
-        # Load state dict with strict parameter for flexibility
-        model.load_state_dict(model_dict, strict=strict)
+        model.load_state_dict(model_dict)
 
         return model
 
 
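For context, the expert-count autodetection that the first hunk removes inferred num_experts by scanning state-dict keys of the form experts.<idx>.<...>. A minimal standalone sketch of that logic, lifted from the removed lines (the helper name infer_num_experts is illustrative, not part of the codebase):

    def infer_num_experts(state_dict):
        """Infer the expert count from keys like 'experts.3.fc.weight'.

        Mirrors the detection loop removed in this comparison; returns
        None when no 'experts.<idx>.' keys are found.
        """
        max_expert_idx = -1
        for key in state_dict.keys():
            if key.startswith("experts."):
                parts = key.split(".")
                if len(parts) > 1:
                    try:
                        max_expert_idx = max(max_expert_idx, int(parts[1]))
                    except ValueError:
                        pass  # non-numeric segment after 'experts.', skip it
        return max_expert_idx + 1 if max_expert_idx >= 0 else None

After this change the model is built from cls(config, **kwargs) before the checkpoint is read, so a caller loading an AIIAmoe checkpoint presumably has to pass structural parameters explicitly, e.g. AIIAmoe.load(path, num_experts=4).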
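The second hunk also drops the strict pass-through. In PyTorch, nn.Module.load_state_dict defaults to strict=True and raises a RuntimeError on missing or unexpected keys, so after this change a partially matching checkpoint can no longer be loaded through load(). A minimal sketch of the behaviour the removed flag exposed (the Tiny module is illustrative only):

    import torch
    import torch.nn as nn

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 4)

    model = Tiny()
    state = model.state_dict()
    state["extra.weight"] = torch.zeros(1)  # inject an unexpected key

    # strict=False reports the mismatch instead of raising; with the default
    # strict=True (what load() now always uses) this call would raise.
    result = model.load_state_dict(state, strict=False)
    print(result.unexpected_keys)  # ['extra.weight']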