updated model for MoE
parent 899f714554
commit 81c9ae9d9d
@@ -10,7 +10,7 @@ include = '\.pyi?$'
 [project]
 name = "aiia"
-version = "0.1.3"
+version = "0.1.4"
 description = "AIIA Deep Learning Model Implementation"
 readme = "README.md"
 authors = [

@@ -1,6 +1,6 @@
 [metadata]
 name = aiia
-version = 0.1.3
+version = 0.1.4
 author = Falko Habel
 author_email = falko.habel@gmx.de
 description = AIIA deep learning model implementation

@@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
 from .pretrain.pretrainer import Pretrainer, ProjectionHead
 
 
-__version__ = "0.1.3"
+__version__ = "0.1.4"

@@ -177,7 +177,7 @@ class AIIADataset(torch.utils.data.Dataset):
         self.items = items
         self.pretraining = pretraining
         self.transform = transforms.Compose([
-            transforms.Resize((224, 224)),
+            transforms.Resize((410, 410)),
             transforms.ToTensor()
         ])
 
@@ -193,7 +193,7 @@ class AIIADataset(torch.utils.data.Dataset):
                 raise ValueError(f"Invalid image at index {idx}")
 
             image = self.transform(image)
-            if image.shape != (3, 224, 224):
+            if image.shape != (3, 410, 410):
                 raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
 
             if task == 'denoise':
@@ -215,7 +215,7 @@ class AIIADataset(torch.utils.data.Dataset):
             if not isinstance(image, Image.Image):
                 raise ValueError(f"Invalid image at index {idx}")
             image = self.transform(image)
-            if image.shape != (3, 224, 224):
+            if image.shape != (3, 410, 410):
                 raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
             return image, label
         else:
@@ -223,6 +223,6 @@ class AIIADataset(torch.utils.data.Dataset):
                 image = self.transform(item)
             else:
                 image = self.transform(item[0])
-            if image.shape != (3, 224, 224):
+            if image.shape != (3, 410, 410):
                 raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
             return image

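The dataset hunks above change the Resize target and its shape guards in lockstep. As a minimal standalone sketch (not part of this commit; it assumes Pillow and torchvision are available and that dataset items are RGB images), the updated transform yields exactly the (3, 410, 410) tensors the new checks expect:

# Hypothetical check mirroring the updated AIIADataset transform.
# Resize((410, 410)) + ToTensor() turns any RGB image into a 3 x 410 x 410
# float tensor, which is what the new shape guard compares against.
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((410, 410)),
    transforms.ToTensor(),
])

dummy = Image.new("RGB", (640, 480))   # stand-in for a dataset image
tensor = transform(dummy)
assert tensor.shape == (3, 410, 410)   # same comparison the dataset performs
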
@@ -177,35 +177,68 @@ class AIIAExpert(AIIA):
         # Process input through the base CNN
         return self.base_cnn(x)
 
 
 class AIIAmoe(AIIA):
     def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
+        super().__init__()
+        self.config = config
 
-        # Update config with new parameters if provided
+        # Update the config to include the number of experts.
         self.config.num_experts = num_experts
 
-        # Initialize multiple experts using chosen base class
+        # Initialize multiple experts from the chosen base class.
         self.experts = nn.ModuleList([
-            AIIAExpert(self.config, base_class=base_class, **kwargs)
-            for _ in range(self.config.num_experts)
+            AIIAExpert(self.config, base_class=base_class, **kwargs)
+            for _ in range(num_experts)
         ])
 
-        # Create gating network
+        # To generate gating weights, we first need to determine the feature dimension.
+        # Each expert is assumed to return an output of shape (B, C, H, W); after averaging over H and W,
+        # we obtain a tensor of shape (B, C) where C is the number of channels (here assumed to be 410).
+        gate_in_features = 410  # Adjust this if your expert output changes.
+
+        # Create a gating network that maps the aggregated features to num_experts weights.
         self.gate = nn.Sequential(
-            nn.Linear(self.config.hidden_size, self.config.num_experts),
+            nn.Linear(gate_in_features, num_experts),
             nn.Softmax(dim=1)
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for the Mixture-of-Experts model.
+
+        Args:
+            x (torch.Tensor): Input tensor
+
+        Returns:
+            torch.Tensor: Merged output tensor from all experts
+        """
+        # Stack the outputs from each expert.
+        # Each expert's output should have shape (B, C, H, W). After stacking, expert_outputs has shape:
+        # (B, num_experts, C, H, W)
         expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
-        gate_weights = self.gate(torch.mean(expert_outputs, (2, 3)))
-        merged_output = torch.sum(
-            expert_outputs * gate_weights.unsqueeze(2).unsqueeze(3), dim=1
-        )
+
+        # Aggregate spatial features: average across the spatial dimensions (H, W).
+        # This results in a tensor with shape (B, num_experts, C)
+        spatial_avg = torch.mean(expert_outputs, dim=(3, 4))
+
+        # To feed the gating network, further average across the expert dimension,
+        # obtaining a tensor of shape (B, C) that represents the global feature summary.
+        gate_input = torch.mean(spatial_avg, dim=1)
+
+        # Compute gating weights using the gating network.
+        # The output gate_weights has shape (B, num_experts)
+        gate_weights = self.gate(gate_input)
+
+        # Expand the gate weights to match the expert outputs shape so they can be combined.
+        # After unsqueezing, gate_weights has shape (B, num_experts, 1, 1, 1)
+        gate_weights_expanded = gate_weights.unsqueeze(2).unsqueeze(3).unsqueeze(4)
+
+        # Multiply each expert's output by its corresponding gating weight and sum over experts.
+        # The merged_output retains the shape (B, C, H, W)
+        merged_output = torch.sum(expert_outputs * gate_weights_expanded, dim=1)
         return merged_output
 
 
 class AIIAchunked(AIIA):
     def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
         super().__init__(config=config, **kwargs)
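For orientation, the shape flow of the rewritten forward pass can be reproduced outside the model. The following is a minimal standalone sketch (not part of this commit): plain Conv2d layers stand in for AIIAExpert, and expert outputs are assumed to have shape (B, C, H, W) with C = 410 to match the hard-coded gate_in_features.

import torch
import torch.nn as nn

B, H, W, num_experts = 2, 8, 8, 3
C = 410            # assumed expert channel count, matching gate_in_features
in_channels = 3    # RGB input

# Stand-ins for AIIAExpert: each maps (B, 3, H, W) -> (B, 410, H, W).
experts = nn.ModuleList(
    [nn.Conv2d(in_channels, C, kernel_size=3, padding=1) for _ in range(num_experts)]
)
gate = nn.Sequential(nn.Linear(C, num_experts), nn.Softmax(dim=1))

x = torch.randn(B, in_channels, H, W)

# (B, num_experts, C, H, W): one output per expert, stacked along dim 1.
expert_outputs = torch.stack([expert(x) for expert in experts], dim=1)

# Average over H and W, then over the expert dimension -> (B, C) gate input.
gate_input = expert_outputs.mean(dim=(3, 4)).mean(dim=1)

# (B, num_experts) gating weights that sum to 1 per sample.
gate_weights = gate(gate_input)

# Broadcast to (B, num_experts, 1, 1, 1) and take the weighted sum over experts.
merged = (expert_outputs * gate_weights[:, :, None, None, None]).sum(dim=1)
assert merged.shape == (B, C, H, W)

Because the gate is fed a global average over all experts' features, every expert is evaluated on every input (a dense mixture); the gating weights only control how the expert outputs are blended.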