Merge pull request 'feat/tests' (#32) from feat/tests into main
Reviewed-on: #32
This commit is contained in:
commit
e0b3cfb20f
|
@ -0,0 +1,10 @@
|
||||||
|
[run]
|
||||||
|
branch = True
|
||||||
|
source = src
|
||||||
|
omit =
|
||||||
|
*/tests/*
|
||||||
|
*/migrations/*
|
||||||
|
|
||||||
|
[report]
|
||||||
|
show_missing = True
|
||||||
|
fail_under = 80
|
|
@ -0,0 +1,36 @@
|
||||||
|
name: Run VectorLoader Script
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
Explore-Gitea-Actions:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container: catthehacker/ubuntu:act-latest
|
||||||
|
steps:
|
||||||
|
- name: Check out repository code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v3
|
||||||
|
with:
|
||||||
|
python-version: '3.11.7'
|
||||||
|
|
||||||
|
- name: Clone additional repository
|
||||||
|
run: |
|
||||||
|
git config --global credential.helper cache
|
||||||
|
git clone https://fabel:${{ secrets.CICD }}@gitea.fabelous.app/fabel/VectorLoader.git
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
cd VectorLoader
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
- name: Run vectorizing
|
||||||
|
env:
|
||||||
|
VECTORDB_TOKEN: ${{ secrets.VECTORDB_TOKEN }}
|
||||||
|
run: |
|
||||||
|
python -m src.run --full
|
|
@ -0,0 +1,37 @@
|
||||||
|
name: Gitea Actions For AIIA
|
||||||
|
run-name: ${{ gitea.actor }} is testing out Gitea Actions 🚀
|
||||||
|
on: [push]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
Explore-Gitea-Actions:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container: catthehacker/ubuntu:act-latest
|
||||||
|
steps:
|
||||||
|
- name: Check out repository code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11.7'
|
||||||
|
|
||||||
|
- name: Cache pip and model
|
||||||
|
uses: actions/cache@v3
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cache/pip
|
||||||
|
./fabel
|
||||||
|
key: ${{ runner.os }}-pip-model-${{ hashFiles('requirements.txt', 'requirements-dev.txt') }}
|
||||||
|
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-pip-model-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -r requirements.txt
|
||||||
|
pip install -r requirements-dev.txt
|
||||||
|
pip install -e .
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
pytest tests/
|
|
@ -10,7 +10,7 @@ include = '\.pyi?$'
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "aiia"
|
name = "aiia"
|
||||||
version = "0.1.6"
|
version = "0.2.0"
|
||||||
description = "AIIA Deep Learning Model Implementation"
|
description = "AIIA Deep Learning Model Implementation"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
[pytest]
|
||||||
|
testpaths = tests/
|
||||||
|
python_files = test_*.py
|
|
@ -0,0 +1,2 @@
|
||||||
|
pytest
|
||||||
|
pytest-mock
|
|
@ -1,6 +1,6 @@
|
||||||
[metadata]
|
[metadata]
|
||||||
name = aiia
|
name = aiia
|
||||||
version = 0.1.6
|
version = 0.2.0
|
||||||
author = Falko Habel
|
author = Falko Habel
|
||||||
author_email = falko.habel@gmx.de
|
author_email = falko.habel@gmx.de
|
||||||
description = AIIA deep learning model implementation
|
description = AIIA deep learning model implementation
|
||||||
|
|
|
@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
|
||||||
from .pretrain.pretrainer import Pretrainer, ProjectionHead
|
from .pretrain.pretrainer import Pretrainer, ProjectionHead
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.1.6"
|
__version__ = "0.2.0"
|
||||||
|
|
|
@ -14,10 +14,10 @@ class FilePathLoader:
|
||||||
self.label_column = label_column
|
self.label_column = label_column
|
||||||
self.successful_count = 0
|
self.successful_count = 0
|
||||||
self.skipped_count = 0
|
self.skipped_count = 0
|
||||||
|
|
||||||
if self.file_path_column not in dataset.column_names:
|
if self.file_path_column not in dataset.columns:
|
||||||
raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")
|
raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")
|
||||||
|
|
||||||
def _get_image(self, item):
|
def _get_image(self, item):
|
||||||
try:
|
try:
|
||||||
path = item[self.file_path_column]
|
path = item[self.file_path_column]
|
||||||
|
@ -32,7 +32,7 @@ class FilePathLoader:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading image from {path}: {e}")
|
print(f"Error loading image from {path}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_item(self, idx):
|
def get_item(self, idx):
|
||||||
item = self.dataset.iloc[idx]
|
item = self.dataset.iloc[idx]
|
||||||
image = self._get_image(item)
|
image = self._get_image(item)
|
||||||
|
@ -46,7 +46,7 @@ class FilePathLoader:
|
||||||
else:
|
else:
|
||||||
self.skipped_count += 1
|
self.skipped_count += 1
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def print_summary(self):
|
def print_summary(self):
|
||||||
print(f"Successfully converted {self.successful_count} images.")
|
print(f"Successfully converted {self.successful_count} images.")
|
||||||
print(f"Skipped {self.skipped_count} images due to errors.")
|
print(f"Skipped {self.skipped_count} images due to errors.")
|
||||||
|
@ -58,14 +58,14 @@ class JPGImageLoader:
|
||||||
self.label_column = label_column
|
self.label_column = label_column
|
||||||
self.successful_count = 0
|
self.successful_count = 0
|
||||||
self.skipped_count = 0
|
self.skipped_count = 0
|
||||||
|
|
||||||
if self.bytes_column not in dataset.columns:
|
if self.bytes_column not in dataset.columns:
|
||||||
raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")
|
raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")
|
||||||
|
|
||||||
def _get_image(self, item):
|
def _get_image(self, item):
|
||||||
try:
|
try:
|
||||||
data = item[self.bytes_column]
|
data = item[self.bytes_column]
|
||||||
|
|
||||||
if isinstance(data, str) and data.startswith("b'"):
|
if isinstance(data, str) and data.startswith("b'"):
|
||||||
cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
|
cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
|
||||||
bytes_data = cleaned_data
|
bytes_data = cleaned_data
|
||||||
|
@ -73,7 +73,7 @@ class JPGImageLoader:
|
||||||
bytes_data = base64.b64decode(data)
|
bytes_data = base64.b64decode(data)
|
||||||
else:
|
else:
|
||||||
bytes_data = data
|
bytes_data = data
|
||||||
|
|
||||||
img_bytes = io.BytesIO(bytes_data)
|
img_bytes = io.BytesIO(bytes_data)
|
||||||
image = Image.open(img_bytes)
|
image = Image.open(img_bytes)
|
||||||
if image.mode == 'RGBA':
|
if image.mode == 'RGBA':
|
||||||
|
@ -86,7 +86,7 @@ class JPGImageLoader:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading image from bytes: {e}")
|
print(f"Error loading image from bytes: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_item(self, idx):
|
def get_item(self, idx):
|
||||||
item = self.dataset.iloc[idx]
|
item = self.dataset.iloc[idx]
|
||||||
image = self._get_image(item)
|
image = self._get_image(item)
|
||||||
|
@ -100,37 +100,41 @@ class JPGImageLoader:
|
||||||
else:
|
else:
|
||||||
self.skipped_count += 1
|
self.skipped_count += 1
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def print_summary(self):
|
def print_summary(self):
|
||||||
print(f"Successfully converted {self.successful_count} images.")
|
print(f"Successfully converted {self.successful_count} images.")
|
||||||
print(f"Skipped {self.skipped_count} images due to errors.")
|
print(f"Skipped {self.skipped_count} images due to errors.")
|
||||||
|
|
||||||
class AIIADataLoader:
|
class AIIADataLoader:
|
||||||
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
|
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path",
|
||||||
|
label_column=None, pretraining=False, **dataloader_kwargs):
|
||||||
|
if column not in dataset.columns:
|
||||||
|
raise ValueError(f"Column '{column}' not found in dataset")
|
||||||
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.val_split = val_split
|
self.val_split = val_split
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.pretraining = pretraining
|
self.pretraining = pretraining
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
sample_value = dataset[column].iloc[0]
|
sample_value = dataset[column].iloc[0]
|
||||||
is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
|
is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
|
||||||
isinstance(sample_value, bytes) or
|
isinstance(sample_value, bytes) or
|
||||||
sample_value.startswith("b'") or
|
sample_value.startswith("b'") or
|
||||||
sample_value.startswith(('b"', 'data:image'))
|
sample_value.startswith(('b"', 'data:image'))
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_bytes_or_bytestring:
|
if is_bytes_or_bytestring:
|
||||||
self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
|
self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
|
||||||
else:
|
else:
|
||||||
sample_paths = dataset[column].dropna().head(1).astype(str)
|
sample_paths = dataset[column].dropna().head(1).astype(str)
|
||||||
filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'
|
filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'
|
||||||
|
|
||||||
if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
|
if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
|
||||||
self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
|
self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
|
||||||
else:
|
else:
|
||||||
self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
|
self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
|
||||||
|
|
||||||
self.items = []
|
self.items = []
|
||||||
for idx in range(len(dataset)):
|
for idx in range(len(dataset)):
|
||||||
item = self.loader.get_item(idx)
|
item = self.loader.get_item(idx)
|
||||||
|
@ -141,33 +145,32 @@ class AIIADataLoader:
|
||||||
self.items.append((img, 'rotate', 0))
|
self.items.append((img, 'rotate', 0))
|
||||||
else:
|
else:
|
||||||
self.items.append(item)
|
self.items.append(item)
|
||||||
|
|
||||||
if not self.items:
|
if not self.items:
|
||||||
raise ValueError("No valid items were loaded from the dataset")
|
raise ValueError("No valid items were loaded from the dataset")
|
||||||
|
|
||||||
|
|
||||||
train_indices, val_indices = self._split_data()
|
train_indices, val_indices = self._split_data()
|
||||||
|
|
||||||
self.train_dataset = self._create_subset(train_indices)
|
self.train_dataset = self._create_subset(train_indices)
|
||||||
self.val_dataset = self._create_subset(val_indices)
|
self.val_dataset = self._create_subset(val_indices)
|
||||||
|
|
||||||
self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
|
self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
|
||||||
self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)
|
self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)
|
||||||
|
|
||||||
def _split_data(self):
|
def _split_data(self):
|
||||||
if len(self.items) == 0:
|
if len(self.items) == 0:
|
||||||
raise ValueError("No items to split")
|
raise ValueError("No items to split")
|
||||||
|
|
||||||
num_samples = len(self.items)
|
num_samples = len(self.items)
|
||||||
indices = list(range(num_samples))
|
indices = list(range(num_samples))
|
||||||
random.shuffle(indices)
|
random.shuffle(indices)
|
||||||
|
|
||||||
split_idx = int((1 - self.val_split) * num_samples)
|
split_idx = int((1 - self.val_split) * num_samples)
|
||||||
train_indices = indices[:split_idx]
|
train_indices = indices[:split_idx]
|
||||||
val_indices = indices[split_idx:]
|
val_indices = indices[split_idx:]
|
||||||
|
|
||||||
return train_indices, val_indices
|
return train_indices, val_indices
|
||||||
|
|
||||||
def _create_subset(self, indices):
|
def _create_subset(self, indices):
|
||||||
subset_items = [self.items[i] for i in indices]
|
subset_items = [self.items[i] for i in indices]
|
||||||
return AIIADataset(subset_items, pretraining=self.pretraining)
|
return AIIADataset(subset_items, pretraining=self.pretraining)
|
||||||
|
@ -180,22 +183,24 @@ class AIIADataset(torch.utils.data.Dataset):
|
||||||
transforms.Resize((224, 224)),
|
transforms.Resize((224, 224)),
|
||||||
transforms.ToTensor()
|
transforms.ToTensor()
|
||||||
])
|
])
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.items)
|
return len(self.items)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
item = self.items[idx]
|
item = self.items[idx]
|
||||||
|
|
||||||
if self.pretraining:
|
if self.pretraining:
|
||||||
image, task, label = item
|
image, task, label = item
|
||||||
if not isinstance(image, Image.Image):
|
if not isinstance(image, Image.Image):
|
||||||
raise ValueError(f"Invalid image at index {idx}")
|
raise ValueError(f"Invalid image at index {idx}")
|
||||||
|
|
||||||
|
# Check image dimensions before transform
|
||||||
|
if image.size[0] < 224 or image.size[1] < 224:
|
||||||
|
raise ValueError("Invalid image dimensions")
|
||||||
|
|
||||||
image = self.transform(image)
|
image = self.transform(image)
|
||||||
if image.shape != (3, 224, 224):
|
|
||||||
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
|
||||||
|
|
||||||
if task == 'denoise':
|
if task == 'denoise':
|
||||||
noise_std = 0.1
|
noise_std = 0.1
|
||||||
noisy_img = image + torch.randn_like(image) * noise_std
|
noisy_img = image + torch.randn_like(image) * noise_std
|
||||||
|
@ -214,15 +219,20 @@ class AIIADataset(torch.utils.data.Dataset):
|
||||||
image, label = item
|
image, label = item
|
||||||
if not isinstance(image, Image.Image):
|
if not isinstance(image, Image.Image):
|
||||||
raise ValueError(f"Invalid image at index {idx}")
|
raise ValueError(f"Invalid image at index {idx}")
|
||||||
|
|
||||||
|
# Check image dimensions before transform
|
||||||
|
if image.size[0] < 224 or image.size[1] < 224:
|
||||||
|
raise ValueError("Invalid image dimensions")
|
||||||
|
|
||||||
image = self.transform(image)
|
image = self.transform(image)
|
||||||
if image.shape != (3, 224, 224):
|
|
||||||
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
|
||||||
return image, label
|
return image, label
|
||||||
else:
|
else:
|
||||||
if isinstance(item, Image.Image):
|
image = item[0] if isinstance(item, tuple) else item
|
||||||
image = self.transform(item)
|
if not isinstance(image, Image.Image):
|
||||||
else:
|
raise ValueError(f"Invalid image at index {idx}")
|
||||||
image = self.transform(item[0])
|
|
||||||
if image.shape != (3, 224, 224):
|
# Check image dimensions before transform
|
||||||
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
if image.size[0] < 224 or image.size[1] < 224:
|
||||||
|
raise ValueError("Invalid image dimensions")
|
||||||
|
image = self.transform(image)
|
||||||
return image
|
return image
|
||||||
|
|
|
@ -23,12 +23,36 @@ class AIIA(nn.Module):
|
||||||
self.config.save(path)
|
self.config.save(path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path, precision: str = None, **kwargs):
|
def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
|
||||||
config = AIIAConfig.load(path)
|
config = AIIAConfig.load(path)
|
||||||
model = cls(config, **kwargs) # Pass kwargs here!
|
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
# Load the state dict to analyze structure
|
||||||
|
model_dict = torch.load(f"{path}/model.pth", map_location=device)
|
||||||
|
|
||||||
|
# Special handling for AIIAmoe - detect number of experts from state_dict
|
||||||
|
if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs:
|
||||||
|
# Find maximum expert index
|
||||||
|
max_expert_idx = -1
|
||||||
|
for key in model_dict.keys():
|
||||||
|
if key.startswith("experts."):
|
||||||
|
parts = key.split(".")
|
||||||
|
if len(parts) > 1:
|
||||||
|
try:
|
||||||
|
expert_idx = int(parts[1])
|
||||||
|
max_expert_idx = max(max_expert_idx, expert_idx)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if max_expert_idx >= 0:
|
||||||
|
# experts.X keys found, use max_expert_idx + 1 as num_experts
|
||||||
|
kwargs["num_experts"] = max_expert_idx + 1
|
||||||
|
|
||||||
|
# Create model with detected structural parameters
|
||||||
|
model = cls(config, **kwargs)
|
||||||
|
|
||||||
|
# Handle precision conversion
|
||||||
dtype = None
|
dtype = None
|
||||||
|
|
||||||
if precision is not None:
|
if precision is not None:
|
||||||
if precision.lower() == 'fp16':
|
if precision.lower() == 'fp16':
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
|
@ -40,14 +64,14 @@ class AIIA(nn.Module):
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
|
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
|
||||||
|
|
||||||
model_dict = torch.load(f"{path}/model.pth", map_location=device)
|
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
for key, param in model_dict.items():
|
for key, param in model_dict.items():
|
||||||
if torch.is_tensor(param):
|
if torch.is_tensor(param):
|
||||||
model_dict[key] = param.to(dtype)
|
model_dict[key] = param.to(dtype)
|
||||||
|
|
||||||
model.load_state_dict(model_dict)
|
# Load state dict with strict parameter for flexibility
|
||||||
|
model.load_state_dict(model_dict, strict=strict)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
import pytest
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torchvision import transforms
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from aiia.data.DataLoader import FilePathLoader, JPGImageLoader, AIIADataLoader, AIIADataset
|
||||||
|
|
||||||
|
def create_sample_dataset(file_paths=None):
|
||||||
|
if file_paths is None:
|
||||||
|
file_paths = ['path/to/image1.jpg', 'path/to/image2.png']
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'file_path': file_paths,
|
||||||
|
'label': [0] * len(file_paths) # Match length of labels to file_paths
|
||||||
|
}
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
return df
|
||||||
|
|
||||||
|
def create_sample_bytes_dataset(bytes_data=None):
|
||||||
|
if bytes_data is None:
|
||||||
|
bytes_data = [b'fake_image_data_1', b'fake_image_data_2']
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'jpg': bytes_data,
|
||||||
|
'label': [0] * len(bytes_data) # Match length of labels to bytes_data
|
||||||
|
}
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
return df
|
||||||
|
|
||||||
|
def test_file_path_loader(mocker):
|
||||||
|
# Mock Image.open to return a fake image
|
||||||
|
mock_image = Image.new('RGB', (224, 224))
|
||||||
|
mocker.patch('PIL.Image.open', return_value=mock_image)
|
||||||
|
|
||||||
|
dataset = create_sample_dataset()
|
||||||
|
loader = FilePathLoader(dataset, label_column='label') # Added label_column
|
||||||
|
item = loader.get_item(0)
|
||||||
|
assert isinstance(item[0], Image.Image)
|
||||||
|
assert item[1] == 0
|
||||||
|
|
||||||
|
loader.print_summary()
|
||||||
|
|
||||||
|
def test_jpg_image_loader(mocker):
|
||||||
|
# Mock Image.open to return a fake image
|
||||||
|
mock_image = Image.new('RGB', (224, 224))
|
||||||
|
mocker.patch('PIL.Image.open', return_value=mock_image)
|
||||||
|
|
||||||
|
dataset = create_sample_bytes_dataset()
|
||||||
|
loader = JPGImageLoader(dataset, label_column='label') # Added label_column
|
||||||
|
item = loader.get_item(0)
|
||||||
|
assert isinstance(item[0], Image.Image)
|
||||||
|
assert item[1] == 0
|
||||||
|
|
||||||
|
loader.print_summary()
|
||||||
|
|
||||||
|
def test_aiia_data_loader(mocker):
|
||||||
|
# Mock Image.open to return a fake image
|
||||||
|
mock_image = Image.new('RGB', (224, 224))
|
||||||
|
mocker.patch('PIL.Image.open', return_value=mock_image)
|
||||||
|
|
||||||
|
dataset = create_sample_dataset()
|
||||||
|
data_loader = AIIADataLoader(dataset, batch_size=2, label_column='label')
|
||||||
|
|
||||||
|
# Test train loader
|
||||||
|
batch = next(iter(data_loader.train_loader))
|
||||||
|
assert isinstance(batch, list)
|
||||||
|
assert len(batch) == 2 # (images, labels)
|
||||||
|
assert batch[0].shape[0] == 1 # batch size
|
||||||
|
|
||||||
|
def test_aiia_dataset():
|
||||||
|
items = [(Image.new('RGB', (224, 224)), 0), (Image.new('RGB', (224, 224)), 1)]
|
||||||
|
dataset = AIIADataset(items)
|
||||||
|
|
||||||
|
assert len(dataset) == 2
|
||||||
|
|
||||||
|
item = dataset[0]
|
||||||
|
assert isinstance(item[0], torch.Tensor)
|
||||||
|
assert item[1] == 0
|
||||||
|
|
||||||
|
def test_aiia_dataset_pre_training():
|
||||||
|
items = [(Image.new('RGB', (224, 224)), 'denoise', Image.new('RGB', (224, 224)))]
|
||||||
|
dataset = AIIADataset(items, pretraining=True)
|
||||||
|
|
||||||
|
assert len(dataset) == 1
|
||||||
|
|
||||||
|
item = dataset[0]
|
||||||
|
assert isinstance(item[0], torch.Tensor)
|
||||||
|
assert isinstance(item[2], str)
|
||||||
|
|
||||||
|
def test_aiia_dataset_invalid_image():
|
||||||
|
items = [(Image.new('RGB', (50, 50)), 0)] # Create small image
|
||||||
|
dataset = AIIADataset(items)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid image dimensions"):
|
||||||
|
dataset[0]
|
||||||
|
|
||||||
|
def test_aiia_dataset_invalid_task():
|
||||||
|
items = [(Image.new('RGB', (224, 224)), 'invalid_task', Image.new('RGB', (224, 224)))]
|
||||||
|
dataset = AIIADataset(items, pretraining=True)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
dataset[0]
|
||||||
|
|
||||||
|
def test_aiia_data_loader_invalid_column():
|
||||||
|
dataset = create_sample_dataset()
|
||||||
|
with pytest.raises(ValueError, match="Column 'invalid_column' not found"):
|
||||||
|
AIIADataLoader(dataset, column='invalid_column')
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main(['-v'])
|
|
@ -0,0 +1,133 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig
|
||||||
|
|
||||||
|
def test_aiiabase_creation():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIABase(config)
|
||||||
|
assert isinstance(model, AIIABase)
|
||||||
|
|
||||||
|
def test_aiiabase_save_load():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIABase(config)
|
||||||
|
save_path = "test_aiiabase_save_load"
|
||||||
|
|
||||||
|
# Save the model
|
||||||
|
model.save(save_path)
|
||||||
|
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
||||||
|
assert os.path.exists(os.path.join(save_path, "config.json"))
|
||||||
|
|
||||||
|
# Load the model
|
||||||
|
loaded_model = AIIABase.load(save_path)
|
||||||
|
|
||||||
|
# Check if the loaded model is an instance of AIIABase
|
||||||
|
assert isinstance(loaded_model, AIIABase)
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
os.remove(os.path.join(save_path, "model.pth"))
|
||||||
|
os.remove(os.path.join(save_path, "config.json"))
|
||||||
|
os.rmdir(save_path)
|
||||||
|
|
||||||
|
def test_aiiabase_shared_creation():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIABaseShared(config)
|
||||||
|
assert isinstance(model, AIIABaseShared)
|
||||||
|
|
||||||
|
def test_aiiabase_shared_save_load():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIABaseShared(config)
|
||||||
|
save_path = "test_aiiabase_shared_save_load"
|
||||||
|
|
||||||
|
# Save the model
|
||||||
|
model.save(save_path)
|
||||||
|
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
||||||
|
assert os.path.exists(os.path.join(save_path, "config.json"))
|
||||||
|
|
||||||
|
# Load the model
|
||||||
|
loaded_model = AIIABaseShared.load(save_path)
|
||||||
|
|
||||||
|
# Check if the loaded model is an instance of AIIABaseShared
|
||||||
|
assert isinstance(loaded_model, AIIABaseShared)
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
os.remove(os.path.join(save_path, "model.pth"))
|
||||||
|
os.remove(os.path.join(save_path, "config.json"))
|
||||||
|
os.rmdir(save_path)
|
||||||
|
|
||||||
|
def test_aiiaexpert_creation():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIAExpert(config)
|
||||||
|
assert isinstance(model, AIIAExpert)
|
||||||
|
|
||||||
|
def test_aiiaexpert_save_load():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIAExpert(config)
|
||||||
|
save_path = "test_aiiaexpert_save_load"
|
||||||
|
|
||||||
|
# Save the model
|
||||||
|
model.save(save_path)
|
||||||
|
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
||||||
|
assert os.path.exists(os.path.join(save_path, "config.json"))
|
||||||
|
|
||||||
|
# Load the model
|
||||||
|
loaded_model = AIIAExpert.load(save_path)
|
||||||
|
|
||||||
|
# Check if the loaded model is an instance of AIIAExpert
|
||||||
|
assert isinstance(loaded_model, AIIAExpert)
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
os.remove(os.path.join(save_path, "model.pth"))
|
||||||
|
os.remove(os.path.join(save_path, "config.json"))
|
||||||
|
os.rmdir(save_path)
|
||||||
|
|
||||||
|
def test_aiiamoe_creation():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIAmoe(config, num_experts=5)
|
||||||
|
assert isinstance(model, AIIAmoe)
|
||||||
|
|
||||||
|
def test_aiiamoe_save_load():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIAmoe(config, num_experts=5)
|
||||||
|
save_path = "test_aiiamoe_save_load"
|
||||||
|
|
||||||
|
# Save the model
|
||||||
|
model.save(save_path)
|
||||||
|
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
||||||
|
assert os.path.exists(os.path.join(save_path, "config.json"))
|
||||||
|
|
||||||
|
# Load the model
|
||||||
|
loaded_model = AIIAmoe.load(save_path)
|
||||||
|
|
||||||
|
# Check if the loaded model is an instance of AIIAmoe
|
||||||
|
assert isinstance(loaded_model, AIIAmoe)
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
os.remove(os.path.join(save_path, "model.pth"))
|
||||||
|
os.remove(os.path.join(save_path, "config.json"))
|
||||||
|
os.rmdir(save_path)
|
||||||
|
|
||||||
|
def test_aiiachunked_creation():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIAchunked(config)
|
||||||
|
assert isinstance(model, AIIAchunked)
|
||||||
|
|
||||||
|
def test_aiiachunked_save_load():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIAchunked(config)
|
||||||
|
save_path = "test_aiiachunked_save_load"
|
||||||
|
|
||||||
|
# Save the model
|
||||||
|
model.save(save_path)
|
||||||
|
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
||||||
|
assert os.path.exists(os.path.join(save_path, "config.json"))
|
||||||
|
|
||||||
|
# Load the model
|
||||||
|
loaded_model = AIIAchunked.load(save_path)
|
||||||
|
|
||||||
|
# Check if the loaded model is an instance of AIIAchunked
|
||||||
|
assert isinstance(loaded_model, AIIAchunked)
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
os.remove(os.path.join(save_path, "model.pth"))
|
||||||
|
os.remove(os.path.join(save_path, "config.json"))
|
||||||
|
os.rmdir(save_path)
|
|
@ -0,0 +1,75 @@
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import pytest
|
||||||
|
import torch.nn as nn
|
||||||
|
from aiia import AIIAConfig
|
||||||
|
|
||||||
|
def test_aiia_config_initialization():
|
||||||
|
config = AIIAConfig()
|
||||||
|
assert config.model_name == "AIIA"
|
||||||
|
assert config.kernel_size == 3
|
||||||
|
assert config.activation_function == "GELU"
|
||||||
|
assert config.hidden_size == 512
|
||||||
|
assert config.num_hidden_layers == 12
|
||||||
|
assert config.num_channels == 3
|
||||||
|
assert config.learning_rate == 5e-5
|
||||||
|
|
||||||
|
def test_aiia_config_custom_initialization():
|
||||||
|
config = AIIAConfig(
|
||||||
|
model_name="CustomModel",
|
||||||
|
kernel_size=5,
|
||||||
|
activation_function="ReLU",
|
||||||
|
hidden_size=1024,
|
||||||
|
num_hidden_layers=8,
|
||||||
|
num_channels=1,
|
||||||
|
learning_rate=1e-4
|
||||||
|
)
|
||||||
|
assert config.model_name == "CustomModel"
|
||||||
|
assert config.kernel_size == 5
|
||||||
|
assert config.activation_function == "ReLU"
|
||||||
|
assert config.hidden_size == 1024
|
||||||
|
assert config.num_hidden_layers == 8
|
||||||
|
assert config.num_channels == 1
|
||||||
|
assert config.learning_rate == 1e-4
|
||||||
|
|
||||||
|
def test_aiia_config_invalid_activation_function():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
AIIAConfig(activation_function="InvalidFunction")
|
||||||
|
|
||||||
|
def test_aiia_config_to_dict():
|
||||||
|
config = AIIAConfig()
|
||||||
|
config_dict = config.to_dict()
|
||||||
|
assert isinstance(config_dict, dict)
|
||||||
|
assert config_dict["model_name"] == "AIIA"
|
||||||
|
assert config_dict["kernel_size"] == 3
|
||||||
|
|
||||||
|
def test_aiia_config_save_and_load():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config = AIIAConfig(model_name="TempModel")
|
||||||
|
save_path = os.path.join(tmpdir, "config")
|
||||||
|
config.save(save_path)
|
||||||
|
|
||||||
|
loaded_config = AIIAConfig.load(save_path)
|
||||||
|
assert loaded_config.model_name == "TempModel"
|
||||||
|
assert loaded_config.kernel_size == 3
|
||||||
|
assert loaded_config.activation_function == "GELU"
|
||||||
|
|
||||||
|
def test_aiia_config_save_and_load_with_custom_attributes():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config = AIIAConfig(model_name="TempModel", custom_attr="value")
|
||||||
|
save_path = os.path.join(tmpdir, "config")
|
||||||
|
config.save(save_path)
|
||||||
|
|
||||||
|
loaded_config = AIIAConfig.load(save_path)
|
||||||
|
assert loaded_config.model_name == "TempModel"
|
||||||
|
assert loaded_config.custom_attr == "value"
|
||||||
|
|
||||||
|
def test_aiia_config_save_and_load_with_nested_attributes():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config = AIIAConfig(model_name="TempModel", nested={"key": "value"})
|
||||||
|
save_path = os.path.join(tmpdir, "config")
|
||||||
|
config.save(save_path)
|
||||||
|
|
||||||
|
loaded_config = AIIAConfig.load(save_path)
|
||||||
|
assert loaded_config.model_name == "TempModel"
|
||||||
|
assert loaded_config.nested == {"key": "value"}
|
|
@ -0,0 +1,308 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from unittest.mock import MagicMock, patch, MagicMock, mock_open, call
|
||||||
|
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
# Test the ProjectionHead class
|
||||||
|
def test_projection_head():
|
||||||
|
head = ProjectionHead(hidden_size=512)
|
||||||
|
x = torch.randn(1, 512, 32, 32)
|
||||||
|
|
||||||
|
# Test denoise task
|
||||||
|
output_denoise = head(x, task='denoise')
|
||||||
|
assert output_denoise.shape == (1, 3, 32, 32)
|
||||||
|
|
||||||
|
# Test rotate task
|
||||||
|
output_rotate = head(x, task='rotate')
|
||||||
|
assert output_rotate.shape == (1, 4)
|
||||||
|
|
||||||
|
# Test the Pretrainer class initialization
|
||||||
|
def test_pretrainer_initialization():
|
||||||
|
config = AIIAConfig()
|
||||||
|
model = AIIABase(config=config)
|
||||||
|
pretrainer = Pretrainer(model=model, learning_rate=0.001, config=config)
|
||||||
|
assert pretrainer.device in ["cuda", "cpu"]
|
||||||
|
assert isinstance(pretrainer.projection_head, ProjectionHead)
|
||||||
|
assert isinstance(pretrainer.optimizer, torch.optim.AdamW)
|
||||||
|
|
||||||
|
# Test the safe_collate method
|
||||||
|
def test_safe_collate():
|
||||||
|
pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
|
||||||
|
batch = [
|
||||||
|
(torch.randn(3, 32, 32), torch.randn(3, 32, 32), 'denoise'),
|
||||||
|
(torch.randn(3, 32, 32), torch.tensor(1), 'rotate')
|
||||||
|
]
|
||||||
|
|
||||||
|
collated_batch = pretrainer.safe_collate(batch)
|
||||||
|
assert 'denoise' in collated_batch
|
||||||
|
assert 'rotate' in collated_batch
|
||||||
|
|
||||||
|
# Test the _process_batch method
|
||||||
|
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
|
||||||
|
def test_process_batch(mock_process_batch):
|
||||||
|
pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
|
||||||
|
batch_data = {
|
||||||
|
'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
|
||||||
|
'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
|
||||||
|
}
|
||||||
|
criterion_denoise = MagicMock()
|
||||||
|
criterion_rotate = MagicMock()
|
||||||
|
|
||||||
|
mock_process_batch.return_value = 0.5
|
||||||
|
loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
|
||||||
|
assert loss == 0.5
|
||||||
|
|
||||||
|
@patch('pandas.concat')
|
||||||
|
@patch('pandas.read_parquet')
|
||||||
|
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||||
|
@patch('os.path.join', return_value='mocked/path/model.pt')
|
||||||
|
@patch('builtins.print') # Add this to mock the print function
|
||||||
|
def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat):
|
||||||
|
"""Test the train method under normal conditions with comprehensive verification."""
|
||||||
|
# Setup test data and mocks
|
||||||
|
real_df = pd.DataFrame({
|
||||||
|
'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]
|
||||||
|
})
|
||||||
|
mock_read_parquet.return_value.head.return_value = real_df
|
||||||
|
mock_concat.return_value = real_df
|
||||||
|
|
||||||
|
# Mock the model and related components
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_projection_head = MagicMock()
|
||||||
|
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = mock_projection_head
|
||||||
|
pretrainer.optimizer = MagicMock()
|
||||||
|
|
||||||
|
# Setup dataset paths and mock batch data
|
||||||
|
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
|
||||||
|
output_path = "AIIA_test"
|
||||||
|
|
||||||
|
# Create mock batch data with proper structure
|
||||||
|
mock_batch_data = {
|
||||||
|
'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
|
||||||
|
'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
|
||||||
|
}
|
||||||
|
|
||||||
|
# Configure batch loss
|
||||||
|
mock_batch_loss = torch.tensor(0.5, requires_grad=True)
|
||||||
|
loader_instance = MagicMock()
|
||||||
|
loader_instance.train_loader = [mock_batch_data]
|
||||||
|
loader_instance.val_loader = [mock_batch_data]
|
||||||
|
mock_data_loader.return_value = loader_instance
|
||||||
|
|
||||||
|
# Execute training with patched methods
|
||||||
|
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \
|
||||||
|
patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \
|
||||||
|
patch.object(Pretrainer, 'save_losses') as mock_save_losses, \
|
||||||
|
patch('builtins.open', mock_open()):
|
||||||
|
|
||||||
|
pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2)
|
||||||
|
|
||||||
|
# Verify method calls
|
||||||
|
assert mock_read_parquet.call_count == len(dataset_paths)
|
||||||
|
assert mock_process_batch.call_count == 2
|
||||||
|
assert mock_validate.call_count == 2
|
||||||
|
|
||||||
|
# Check for "Best model saved!" instead of model.save()
|
||||||
|
mock_print.assert_any_call("Best model saved!")
|
||||||
|
|
||||||
|
mock_save_losses.assert_called_once()
|
||||||
|
|
||||||
|
# Verify state changes
|
||||||
|
assert len(pretrainer.train_losses) == 2
|
||||||
|
assert pretrainer.train_losses == [0.5, 0.5]
|
||||||
|
|
||||||
|
|
||||||
|
# Error cases
|
||||||
|
def test_train_no_dataset_paths():
|
||||||
|
"""Test ValueError when no dataset paths are provided."""
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No dataset paths provided"):
|
||||||
|
pretrainer.train([])
|
||||||
|
|
||||||
|
@patch('pandas.read_parquet')
|
||||||
|
def test_train_all_datasets_fail(mock_read_parquet):
|
||||||
|
"""Test handling when all datasets fail to load."""
|
||||||
|
mock_read_parquet.side_effect = Exception("Failed to load dataset")
|
||||||
|
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No valid datasets could be loaded"):
|
||||||
|
pretrainer.train(dataset_paths)
|
||||||
|
|
||||||
|
# Edge cases
|
||||||
|
@patch('pandas.concat')
|
||||||
|
@patch('pandas.read_parquet')
|
||||||
|
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||||
|
def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
|
||||||
|
"""Test behavior with empty data loaders."""
|
||||||
|
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
||||||
|
mock_read_parquet.return_value.head.return_value = real_df
|
||||||
|
mock_concat.return_value = real_df
|
||||||
|
|
||||||
|
loader_instance = MagicMock()
|
||||||
|
loader_instance.train_loader = [] # Empty train loader
|
||||||
|
loader_instance.val_loader = [] # Empty val loader
|
||||||
|
mock_data_loader.return_value = loader_instance
|
||||||
|
|
||||||
|
mock_model = MagicMock()
|
||||||
|
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = MagicMock()
|
||||||
|
pretrainer.optimizer = MagicMock()
|
||||||
|
|
||||||
|
with patch.object(Pretrainer, 'save_losses') as mock_save_losses:
|
||||||
|
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
|
||||||
|
|
||||||
|
# Verify empty loader behavior
|
||||||
|
assert len(pretrainer.train_losses) == 1
|
||||||
|
assert pretrainer.train_losses[0] == 0.0
|
||||||
|
mock_save_losses.assert_called_once()
|
||||||
|
|
||||||
|
@patch('pandas.concat')
|
||||||
|
@patch('pandas.read_parquet')
|
||||||
|
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||||
|
def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat):
|
||||||
|
"""Test behavior when batch_data is None."""
|
||||||
|
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
||||||
|
mock_read_parquet.return_value.head.return_value = real_df
|
||||||
|
mock_concat.return_value = real_df
|
||||||
|
|
||||||
|
loader_instance = MagicMock()
|
||||||
|
loader_instance.train_loader = [None] # Loader returns None
|
||||||
|
loader_instance.val_loader = []
|
||||||
|
mock_data_loader.return_value = loader_instance
|
||||||
|
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = MagicMock()
|
||||||
|
pretrainer.optimizer = MagicMock()
|
||||||
|
|
||||||
|
with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
|
||||||
|
patch.object(Pretrainer, 'save_losses'):
|
||||||
|
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
|
||||||
|
|
||||||
|
# Verify None batch handling
|
||||||
|
mock_process_batch.assert_not_called()
|
||||||
|
assert pretrainer.train_losses[0] == 0.0
|
||||||
|
|
||||||
|
# Parameter variations
|
||||||
|
@patch('pandas.concat')
|
||||||
|
@patch('pandas.read_parquet')
|
||||||
|
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||||
|
def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat):
|
||||||
|
"""Test that custom parameters are properly passed through."""
|
||||||
|
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
||||||
|
mock_read_parquet.return_value.head.return_value = real_df
|
||||||
|
mock_concat.return_value = real_df
|
||||||
|
|
||||||
|
loader_instance = MagicMock()
|
||||||
|
loader_instance.train_loader = []
|
||||||
|
loader_instance.val_loader = []
|
||||||
|
mock_data_loader.return_value = loader_instance
|
||||||
|
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = MagicMock()
|
||||||
|
pretrainer.optimizer = MagicMock()
|
||||||
|
|
||||||
|
# Custom parameters
|
||||||
|
custom_output_path = "custom/output/path"
|
||||||
|
custom_column = "custom_column"
|
||||||
|
custom_batch_size = 16
|
||||||
|
custom_sample_size = 5000
|
||||||
|
|
||||||
|
with patch.object(Pretrainer, 'save_losses'):
|
||||||
|
pretrainer.train(
|
||||||
|
['path/to/dataset.parquet'],
|
||||||
|
output_path=custom_output_path,
|
||||||
|
column=custom_column,
|
||||||
|
batch_size=custom_batch_size,
|
||||||
|
sample_size=custom_sample_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify custom parameters were used
|
||||||
|
mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size)
|
||||||
|
assert mock_data_loader.call_args[1]['column'] == custom_column
|
||||||
|
assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@patch('pandas.concat')
|
||||||
|
@patch('pandas.read_parquet')
|
||||||
|
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||||
|
@patch('builtins.print') # Add this to mock the print function
|
||||||
|
def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
|
||||||
|
"""Test that model is saved only when validation loss improves."""
|
||||||
|
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
||||||
|
mock_read_parquet.return_value.head.return_value = real_df
|
||||||
|
mock_concat.return_value = real_df
|
||||||
|
|
||||||
|
# Create mock batch data with proper structure
|
||||||
|
mock_batch_data = {
|
||||||
|
'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
|
||||||
|
'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
|
||||||
|
}
|
||||||
|
|
||||||
|
loader_instance = MagicMock()
|
||||||
|
loader_instance.train_loader = [mock_batch_data]
|
||||||
|
loader_instance.val_loader = [mock_batch_data]
|
||||||
|
mock_data_loader.return_value = loader_instance
|
||||||
|
|
||||||
|
mock_model = MagicMock()
|
||||||
|
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = MagicMock()
|
||||||
|
pretrainer.optimizer = MagicMock()
|
||||||
|
|
||||||
|
# Initialize the best validation loss
|
||||||
|
pretrainer.best_val_loss = float('inf')
|
||||||
|
|
||||||
|
mock_batch_loss = torch.tensor(0.5, requires_grad=True)
|
||||||
|
|
||||||
|
# Test improving validation loss
|
||||||
|
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
|
||||||
|
patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \
|
||||||
|
patch.object(Pretrainer, 'save_losses'):
|
||||||
|
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
|
||||||
|
|
||||||
|
# Check for "Best model saved!" 3 times
|
||||||
|
assert mock_print.call_args_list.count(call("Best model saved!")) == 3
|
||||||
|
|
||||||
|
# Reset for next test
|
||||||
|
mock_print.reset_mock()
|
||||||
|
pretrainer.train_losses = []
|
||||||
|
|
||||||
|
# Reset best validation loss for the second test
|
||||||
|
pretrainer.best_val_loss = float('inf')
|
||||||
|
|
||||||
|
# Test fluctuating validation loss
|
||||||
|
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
|
||||||
|
patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
|
||||||
|
patch.object(Pretrainer, 'save_losses'):
|
||||||
|
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
|
||||||
|
|
||||||
|
# Should print "Best model saved!" only on first and third epochs
|
||||||
|
assert mock_print.call_args_list.count(call("Best model saved!")) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
|
||||||
|
def test_validate(mock_process_batch):
|
||||||
|
pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
|
||||||
|
val_loader = [MagicMock()]
|
||||||
|
criterion_denoise = MagicMock()
|
||||||
|
criterion_rotate = MagicMock()
|
||||||
|
|
||||||
|
mock_process_batch.return_value = torch.tensor(0.5)
|
||||||
|
loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate)
|
||||||
|
assert loss == 0.5
|
||||||
|
|
||||||
|
# Test the save_losses method
|
||||||
|
@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses')
|
||||||
|
def test_save_losses(mock_save_losses):
|
||||||
|
pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
|
||||||
|
pretrainer.train_losses = [0.1, 0.2]
|
||||||
|
pretrainer.val_losses = [0.3, 0.4]
|
||||||
|
|
||||||
|
csv_file = 'losses.csv'
|
||||||
|
pretrainer.save_losses(csv_file)
|
||||||
|
mock_save_losses.assert_called_once_with(csv_file)
|
Loading…
Reference in New Issue