updated model for moe

This commit is contained in:
Falko Victor Habel 2025-03-03 17:39:04 +01:00
parent 899f714554
commit 81c9ae9d9d
5 changed files with 57 additions and 24 deletions

View File

@ -10,7 +10,7 @@ include = '\.pyi?$'
[project] [project]
name = "aiia" name = "aiia"
version = "0.1.3" version = "0.1.4"
description = "AIIA Deep Learning Model Implementation" description = "AIIA Deep Learning Model Implementation"
readme = "README.md" readme = "README.md"
authors = [ authors = [

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = aiia name = aiia
version = 0.1.3 version = 0.1.4
author = Falko Habel author = Falko Habel
author_email = falko.habel@gmx.de author_email = falko.habel@gmx.de
description = AIIA deep learning model implementation description = AIIA deep learning model implementation

View File

@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
from .pretrain.pretrainer import Pretrainer, ProjectionHead from .pretrain.pretrainer import Pretrainer, ProjectionHead
__version__ = "0.1.3" __version__ = "0.1.4"

View File

@ -177,7 +177,7 @@ class AIIADataset(torch.utils.data.Dataset):
self.items = items self.items = items
self.pretraining = pretraining self.pretraining = pretraining
self.transform = transforms.Compose([ self.transform = transforms.Compose([
transforms.Resize((224, 224)), transforms.Resize((410, 410)),
transforms.ToTensor() transforms.ToTensor()
]) ])
@ -193,7 +193,7 @@ class AIIADataset(torch.utils.data.Dataset):
raise ValueError(f"Invalid image at index {idx}") raise ValueError(f"Invalid image at index {idx}")
image = self.transform(image) image = self.transform(image)
if image.shape != (3, 224, 224): if image.shape != (3, 410, 410):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}") raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
if task == 'denoise': if task == 'denoise':
@ -215,7 +215,7 @@ class AIIADataset(torch.utils.data.Dataset):
if not isinstance(image, Image.Image): if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}") raise ValueError(f"Invalid image at index {idx}")
image = self.transform(image) image = self.transform(image)
if image.shape != (3, 224, 224): if image.shape != (3, 410, 410):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}") raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
return image, label return image, label
else: else:
@ -223,6 +223,6 @@ class AIIADataset(torch.utils.data.Dataset):
image = self.transform(item) image = self.transform(item)
else: else:
image = self.transform(item[0]) image = self.transform(item[0])
if image.shape != (3, 224, 224): if image.shape != (3, 410, 410):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}") raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
return image return image

View File

@ -177,35 +177,68 @@ class AIIAExpert(AIIA):
# Process input through the base CNN # Process input through the base CNN
return self.base_cnn(x) return self.base_cnn(x)
class AIIAmoe(AIIA): class AIIAmoe(AIIA):
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs): def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
super().__init__(config=config, **kwargs) super().__init__()
self.config = self.config self.config = config
# Update config with new parameters if provided # Update the config to include the number of experts.
self.config.num_experts = num_experts self.config.num_experts = num_experts
# Initialize multiple experts using chosen base class # Initialize multiple experts from the chosen base class.
self.experts = nn.ModuleList([ self.experts = nn.ModuleList([
AIIAExpert(self.config, base_class=base_class, **kwargs) AIIAExpert(self.config, base_class=base_class, **kwargs)
for _ in range(self.config.num_experts) for _ in range(num_experts)
]) ])
# Create gating network # To generate gating weights, we first need to determine the feature dimension.
# Each expert is assumed to return an output of shape (B, C, H, W); after averaging over H and W,
# we obtain a tensor of shape (B, C) where C is the number of channels (here assumed to be 410).
gate_in_features = 410 # Adjust this if your expert output changes.
# Create a gating network that maps the aggregated features to num_experts weights.
self.gate = nn.Sequential( self.gate = nn.Sequential(
nn.Linear(self.config.hidden_size, self.config.num_experts), nn.Linear(gate_in_features, num_experts),
nn.Softmax(dim=1) nn.Softmax(dim=1)
) )
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the Mixture-of-Experts model.
Args:
x (torch.Tensor): Input tensor
Returns:
torch.Tensor: Merged output tensor from all experts
"""
# Stack the outputs from each expert.
# Each expert's output should have shape (B, C, H, W). After stacking, expert_outputs has shape:
# (B, num_experts, C, H, W)
expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1) expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
gate_weights = self.gate(torch.mean(expert_outputs, (2, 3)))
merged_output = torch.sum( # Aggregate spatial features: average across the spatial dimensions (H, W).
expert_outputs * gate_weights.unsqueeze(2).unsqueeze(3), dim=1 # This results in a tensor with shape (B, num_experts, C)
) spatial_avg = torch.mean(expert_outputs, dim=(3, 4))
# To feed the gating network, further average across the expert dimension,
# obtaining a tensor of shape (B, C) that represents the global feature summary.
gate_input = torch.mean(spatial_avg, dim=1)
# Compute gating weights using the gating network.
# The output gate_weights has shape (B, num_experts)
gate_weights = self.gate(gate_input)
# Expand the gate weights to match the expert outputs shape so they can be combined.
# After unsqueezing, gate_weights has shape (B, num_experts, 1, 1, 1)
gate_weights_expanded = gate_weights.unsqueeze(2).unsqueeze(3).unsqueeze(4)
# Multiply each expert's output by its corresponding gating weight and sum over experts.
# The merged_output retains the shape (B, C, H, W)
merged_output = torch.sum(expert_outputs * gate_weights_expanded, dim=1)
return merged_output return merged_output
class AIIAchunked(AIIA): class AIIAchunked(AIIA):
def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs): def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
super().__init__(config=config, **kwargs) super().__init__(config=config, **kwargs)