From 81c9ae9d9d1b6ead90e783281386713c6f381cff Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Mon, 3 Mar 2025 17:39:04 +0100
Subject: [PATCH] updated model for moe

---
 pyproject.toml              |  2 +-
 setup.cfg                   |  2 +-
 src/aiia/__init__.py        |  2 +-
 src/aiia/data/DataLoader.py |  8 ++---
 src/aiia/model/Model.py     | 67 +++++++++++++++++++++++++++----------
 5 files changed, 57 insertions(+), 24 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 203cd25..3c4b7b6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ include = '\.pyi?$'
 
 [project]
 name = "aiia"
-version = "0.1.3"
+version = "0.1.4"
 description = "AIIA Deep Learning Model Implementation"
 readme = "README.md"
 authors = [
diff --git a/setup.cfg b/setup.cfg
index 4b75cda..75b1b49 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = aiia
-version = 0.1.3
+version = 0.1.4
 author = Falko Habel
 author_email = falko.habel@gmx.de
 description = AIIA deep learning model implementation
diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py
index f4edbbf..5786564 100644
--- a/src/aiia/__init__.py
+++ b/src/aiia/__init__.py
@@ -4,4 +4,4 @@
 from .data.DataLoader import DataLoader
 from .pretrain.pretrainer import Pretrainer, ProjectionHead
 
-__version__ = "0.1.3"
+__version__ = "0.1.4"
diff --git a/src/aiia/data/DataLoader.py b/src/aiia/data/DataLoader.py
index 4ba5032..6f3a334 100644
--- a/src/aiia/data/DataLoader.py
+++ b/src/aiia/data/DataLoader.py
@@ -177,7 +177,7 @@ class AIIADataset(torch.utils.data.Dataset):
         self.items = items
         self.pretraining = pretraining
         self.transform = transforms.Compose([
-            transforms.Resize((224, 224)),
+            transforms.Resize((410, 410)),
             transforms.ToTensor()
         ])
 
@@ -193,7 +193,7 @@ class AIIADataset(torch.utils.data.Dataset):
                 raise ValueError(f"Invalid image at index {idx}")
 
             image = self.transform(image)
-            if image.shape != (3, 224, 224):
+            if image.shape != (3, 410, 410):
                 raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
 
             if task == 'denoise':
@@ -215,7 +215,7 @@ class AIIADataset(torch.utils.data.Dataset):
             if not isinstance(image, Image.Image):
                 raise ValueError(f"Invalid image at index {idx}")
             image = self.transform(image)
-            if image.shape != (3, 224, 224):
+            if image.shape != (3, 410, 410):
                 raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
             return image, label
         else:
@@ -223,6 +223,6 @@ class AIIADataset(torch.utils.data.Dataset):
                 image = self.transform(item)
             else:
                 image = self.transform(item[0])
-            if image.shape != (3, 224, 224):
+            if image.shape != (3, 410, 410):
                 raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
             return image
diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py
index 3287eab..4e9296d 100644
--- a/src/aiia/model/Model.py
+++ b/src/aiia/model/Model.py
@@ -177,35 +177,68 @@ class AIIAExpert(AIIA):
         # Process input through the base CNN
         return self.base_cnn(x)
 
-
 class AIIAmoe(AIIA):
     def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
-
-        # Update config with new parameters if provided
+        super().__init__()
+        self.config = config
+
+        # Update the config to include the number of experts.
         self.config.num_experts = num_experts
-
-        # Initialize multiple experts using chosen base class
+
+        # Initialize multiple experts from the chosen base class.
         self.experts = nn.ModuleList([
-            AIIAExpert(self.config, base_class=base_class, **kwargs)
-            for _ in range(self.config.num_experts)
+            AIIAExpert(self.config, base_class=base_class, **kwargs)
+            for _ in range(num_experts)
         ])
-
-        # Create gating network
+
+        # To generate gating weights, we first need to determine the feature dimension.
+        # Each expert is assumed to return an output of shape (B, C, H, W); after averaging over H and W,
+        # we obtain a tensor of shape (B, C) where C is the number of channels (here assumed to be 410).
+        gate_in_features = 410  # Adjust this if your expert output changes.
+
+        # Create a gating network that maps the aggregated features to num_experts weights.
         self.gate = nn.Sequential(
-            nn.Linear(self.config.hidden_size, self.config.num_experts),
+            nn.Linear(gate_in_features, num_experts),
             nn.Softmax(dim=1)
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for the Mixture-of-Experts model.
+
+        Args:
+            x (torch.Tensor): Input tensor
+
+        Returns:
+            torch.Tensor: Merged output tensor from all experts
+        """
+        # Stack the outputs from each expert.
+        # Each expert's output should have shape (B, C, H, W). After stacking, expert_outputs has shape:
+        # (B, num_experts, C, H, W)
         expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
-        gate_weights = self.gate(torch.mean(expert_outputs, (2, 3)))
-        merged_output = torch.sum(
-            expert_outputs * gate_weights.unsqueeze(2).unsqueeze(3), dim=1
-        )
+
+        # Aggregate spatial features: average across the spatial dimensions (H, W).
+        # This results in a tensor with shape (B, num_experts, C)
+        spatial_avg = torch.mean(expert_outputs, dim=(3, 4))
+
+        # To feed the gating network, further average across the expert dimension,
+        # obtaining a tensor of shape (B, C) that represents the global feature summary.
+        gate_input = torch.mean(spatial_avg, dim=1)
+
+        # Compute gating weights using the gating network.
+        # The output gate_weights has shape (B, num_experts)
+        gate_weights = self.gate(gate_input)
+
+        # Expand the gate weights to match the expert outputs shape so they can be combined.
+        # After unsqueezing, gate_weights has shape (B, num_experts, 1, 1, 1)
+        gate_weights_expanded = gate_weights.unsqueeze(2).unsqueeze(3).unsqueeze(4)
+
+        # Multiply each expert's output by its corresponding gating weight and sum over experts.
+        # The merged_output retains the shape (B, C, H, W)
+        merged_output = torch.sum(expert_outputs * gate_weights_expanded, dim=1)
         return merged_output
 
+
 class AIIAchunked(AIIA):
     def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
         super().__init__(config=config, **kwargs)
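
For review purposes only (not part of the applied diff): a minimal, self-contained PyTorch sketch of the new gating path, assuming each expert returns a (B, C, H, W) feature map with C = 410, mirroring the gate_in_features comment above. Because the gate input is averaged over both the spatial and the expert dimensions, its feature size equals the channel count C rather than num_experts * C. All sizes below are illustrative.

    import torch
    import torch.nn as nn

    # Illustrative sizes; C = 410 mirrors the gate_in_features assumption in the patch.
    B, num_experts, C, H, W = 2, 3, 410, 13, 13
    expert_outputs = torch.randn(B, num_experts, C, H, W)   # stacked expert outputs

    # Same gating head shape as in AIIAmoe: aggregated features -> per-expert weights.
    gate = nn.Sequential(nn.Linear(C, num_experts), nn.Softmax(dim=1))

    spatial_avg = expert_outputs.mean(dim=(3, 4))            # (B, num_experts, C)
    gate_input = spatial_avg.mean(dim=1)                     # (B, C)
    gate_weights = gate(gate_input)                          # (B, num_experts), rows sum to 1
    merged = (expert_outputs * gate_weights[:, :, None, None, None]).sum(dim=1)

    print(merged.shape)              # torch.Size([2, 410, 13, 13])
    print(gate_weights.sum(dim=1))   # ~1.0 per sample, up to float error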