Merge pull request 'first_pretraining' (#1) from first_pretraining into develop
Reviewed-on: Fabel/AIIA#1
This commit is contained in commit 6f70ebdf53.

aiia/__init__.py (filename inferred from the package imports in this PR)
@@ -0,0 +1,4 @@
# Import submodules
from .model import AIIA, AIIABase  # (AIIAEncoder is not defined anywhere in this PR, so AIIABase is exported instead)
from .data import AIIADataLoader
from .model.config import AIIAConfig
aiia/data/DataLoader.py (filename inferred from the imports elsewhere in this PR)
@@ -0,0 +1,206 @@
import io
import random
import re

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class FilePathLoader:
    def __init__(self, dataset, file_path_column="file_path", label_column=None):
        self.dataset = dataset
        self.file_path_column = file_path_column
        self.label_column = label_column
        self.successful_count = 0
        self.skipped_count = 0

        # Accept both Hugging Face-style datasets (.column_names) and
        # pandas DataFrames (.columns); the pretraining script passes a DataFrame.
        columns = getattr(dataset, "column_names", None)
        if columns is None:
            columns = list(getattr(dataset, "columns", []))
        if self.file_path_column not in columns:
            raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")

    def _get_image(self, item):
        # Resolve the path before the try block so the error message can use it
        path = item[self.file_path_column]
        try:
            return Image.open(path).convert("RGB")
        except Exception as e:
            print(f"Error loading image from {path}: {e}")
            return None

    def get_item(self, idx):
        # DataFrames index rows via .iloc; dataset objects via plain indexing
        item = self.dataset.iloc[idx] if hasattr(self.dataset, "iloc") else self.dataset[idx]
        image = self._get_image(item)
        if image is None:
            self.skipped_count += 1
            return None
        self.successful_count += 1
        if self.label_column is not None:
            return (image, item.get(self.label_column))
        return (image,)

    def print_summary(self):
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")


class JPGImageLoader:
    def __init__(self, dataset, bytes_column="jpg", label_column=None):
        self.dataset = dataset
        self.bytes_column = bytes_column
        self.label_column = label_column
        self.successful_count = 0
        self.skipped_count = 0

        columns = getattr(dataset, "column_names", None)
        if columns is None:
            columns = list(getattr(dataset, "columns", []))
        if self.bytes_column not in columns:
            raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")

    def _get_image(self, item):
        try:
            bytes_data = item[self.bytes_column]
            return Image.open(io.BytesIO(bytes_data)).convert("RGB")
        except Exception as e:
            print(f"Error loading image from bytes: {e}")
            return None

    def get_item(self, idx):
        item = self.dataset.iloc[idx] if hasattr(self.dataset, "iloc") else self.dataset[idx]
        image = self._get_image(item)
        if image is None:
            self.skipped_count += 1
            return None
        self.successful_count += 1
        if self.label_column is not None:
            return (image, item.get(self.label_column))
        return (image,)

    def print_summary(self):
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")
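As a quick sanity check of the bytes path, here is a self-contained sketch of what JPGImageLoader does with a stored JPEG payload (toy 8x8 image, all values illustrative):

    import io
    from PIL import Image

    # Build an in-memory JPEG, then decode it the way JPGImageLoader does
    buf = io.BytesIO()
    Image.new("RGB", (8, 8), "red").save(buf, format="JPEG")
    img = Image.open(io.BytesIO(buf.getvalue())).convert("RGB")
    print(img.size)  # (8, 8)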
class AIIADataLoader:
    # Not a torch DataLoader subclass: the original super().__init__() call
    # would fail without a dataset argument, and this class only prepares
    # train/val Datasets for real DataLoaders to consume.
    def __init__(self, dataset,
                 batch_size=32,
                 val_split=0.2,
                 seed=42,
                 column="file_path",
                 label_column=None):
        self.batch_size = batch_size
        self.val_split = val_split
        self.seed = seed

        # Determine which loader to use based on the dataset's content:
        # peek at the first non-null value in the column. (The original
        # .astype(str) cast made the isinstance check always false.)
        if hasattr(dataset, "iloc"):
            sample_values = dataset[column].dropna().head(1).tolist()
        else:
            sample_values = [v for v in dataset[column] if v is not None][:1]

        is_bytes_or_bytestring = any(
            isinstance(value, (bytes, bytearray, memoryview))
            for value in sample_values
        )

        if is_bytes_or_bytestring:
            self.loader = JPGImageLoader(
                dataset,
                bytes_column=column,
                label_column=label_column
            )
        else:
            # Otherwise check whether the sample looks like an image file path
            # (pattern intentionally permissive; adjust as needed)
            filepath_pattern = r".*\.(jpg|jpeg|png|gif)$"

            if any(
                re.match(filepath_pattern, str(path), flags=re.IGNORECASE)
                for path in sample_values
            ):
                self.loader = FilePathLoader(
                    dataset,
                    file_path_column=column,
                    label_column=label_column
                )
            else:
                # If neither condition is met, default to JPGImageLoader
                # (assuming bytes are stored as strings)
                self.loader = JPGImageLoader(
                    dataset,
                    bytes_column=column,
                    label_column=label_column
                )

        # Load all items up front, dropping entries that failed to decode
        # (get_item returns None for those, which would crash the split below)
        self.items = [item for item in
                      (self.loader.get_item(idx) for idx in range(len(dataset)))
                      if item is not None]

        # Split into train and validation sets
        train_indices, val_indices = self._split_data()

        # Create datasets for training and validation
        self.train_dataset = self._create_subset(train_indices)
        self.val_dataset = self._create_subset(val_indices)

    def _split_data(self):
        if len(self.items) == 0:
            return [], []

        # Group indices by label so the split is stratified; items without a
        # label (bare (image,) tuples) form their own group instead of being
        # dropped, as the original all-None case silently did
        tasks = [item[1] if len(item) > 1 else None for item in self.items]
        unique_tasks = list(set(tasks))

        rng = random.Random(self.seed)  # seeded so splits are reproducible
        train_indices = []
        val_indices = []

        for task in unique_tasks:
            task_indices = [i for i, t in enumerate(tasks) if t == task]
            n_val = int(len(task_indices) * self.val_split)

            rng.shuffle(task_indices)

            val_indices.extend(task_indices[:n_val])
            train_indices.extend(task_indices[n_val:])

        return train_indices, val_indices

    def _create_subset(self, indices):
        subset_items = [self.items[i] for i in indices]
        return AIIADataset(subset_items)
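The split is stratified per label group and seeded; a standalone toy illustration of the same logic (all values illustrative):

    import random

    items = [("img", lbl) for lbl in ["cat"] * 10 + ["dog"] * 10]
    labels = [it[1] for it in items]
    rng = random.Random(42)

    val_idx, train_idx = [], []
    for lbl in set(labels):
        idx = [i for i, l in enumerate(labels) if l == lbl]
        rng.shuffle(idx)
        n_val = int(len(idx) * 0.2)      # 2 of 10 per class
        val_idx += idx[:n_val]
        train_idx += idx[n_val:]

    print(len(train_idx), len(val_idx))  # 16 4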
class AIIADataset(Dataset):
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        if isinstance(item, tuple) and len(item) == 2:
            # Supervised sample: (image, label)
            image, label = item
            return (image, label)

        # Unlabeled sample: the loaders emit bare (image,) tuples when
        # label_column is None, so build a random self-supervised pretext
        # task (denoise or rotate), which is what the pretraining loop expects
        image = item[0] if isinstance(item, tuple) else item
        if isinstance(image, Image.Image):
            image = transforms.functional.to_tensor(image)
        elif not isinstance(image, torch.Tensor):
            raise ValueError("Invalid item format.")

        task = random.choice(['denoise', 'rotate'])
        if task == 'denoise':
            # Input is a noisy copy; target is the clean image
            noise_std = 0.1
            noisy_img = image + torch.randn_like(image) * noise_std
            return (noisy_img, image, task)
        else:
            # Target is the rotation class index (0..3), not the raw angle,
            # so it can feed CrossEntropyLoss directly
            angles = [0, 90, 180, 270]
            angle_idx = random.randrange(len(angles))
            rotated_img = transforms.functional.rotate(image, angles[angle_idx])
            target = torch.tensor(angle_idx).long()
            return (rotated_img, target, task)
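End-to-end usage of the file above, sketched with a hypothetical parquet file holding a file_path column:

    import pandas as pd
    from torch.utils.data import DataLoader
    from aiia.data import AIIADataLoader

    df = pd.read_parquet("images.parquet")   # hypothetical input file
    data = AIIADataLoader(df, batch_size=16, val_split=0.2, column="file_path")

    # Identity collate_fn keeps each batch as a list of per-sample tuples,
    # so variable image sizes don't break default tensor collation
    train_loader = DataLoader(data.train_dataset, batch_size=data.batch_size,
                              shuffle=True, collate_fn=lambda b: b)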
aiia/data/__init__.py (filename inferred)
@@ -0,0 +1 @@
from .DataLoader import AIIADataLoader
aiia/model/Model.py (filename inferred from the imports elsewhere in this PR)
@@ -0,0 +1,216 @@
import copy  # for deep-copying configs
import os

import torch
from torch import nn

from .config import AIIAConfig  # relative import: config.py lives in the same package


class AIIA(nn.Module):
    def __init__(self, config: AIIAConfig, **kwargs):
        super().__init__()
        # Deep-copy the configuration so subclasses don't mutate a shared object
        self.config = copy.deepcopy(config)

        # Update the config with any additional keyword arguments
        for key, value in kwargs.items():
            setattr(self.config, key, value)

    def save(self, path: str):
        os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), f"{path}/model.pth")
        self.config.save(path)

    @classmethod
    def load(cls, path):
        config = AIIAConfig.load(path)
        model = cls(config)
        model.load_state_dict(torch.load(f"{path}/model.pth"))
        return model
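A minimal save/load round trip, as a sketch (the directory name is illustrative, and AIIABase is the subclass defined further down in this file):

    from aiia.model import AIIABase
    from aiia.model.config import AIIAConfig

    config = AIIAConfig(hidden_size=64, num_hidden_layers=2)  # small, for a quick test
    model = AIIABase(config)
    model.save("checkpoints/aiia-demo")            # writes model.pth and config.json
    restored = AIIABase.load("checkpoints/aiia-demo")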
class AIIABaseShared(AIIA):
    def __init__(self, config: AIIAConfig, num_shared_layers=1, **kwargs):
        super().__init__(config=config, **kwargs)  # parent already copies config and applies kwargs
        self.config.num_shared_layers = num_shared_layers

        # Shared layers (early stages) reuse one weight tensor; each layer
        # gets its own bias. Note that weight tying forces every shared layer
        # to have identical in/out channel counts, so with more than one
        # shared layer num_channels must equal hidden_size.
        self.shared_weights = None
        self.shared_biases = nn.ParameterList([
            nn.Parameter(torch.zeros(self.config.hidden_size))
            for _ in range(self.config.num_shared_layers)
        ])
        self.shared_layers = nn.ModuleList()
        for i in range(self.config.num_shared_layers):
            layer = nn.Conv2d(
                self.config.num_channels,
                self.config.hidden_size,
                kernel_size=self.config.kernel_size,
                padding=1
            )
            if i == 0:
                # The first layer owns the shared weight tensor
                self.shared_weights = layer.weight
            else:
                # Later layers point at the same tensor
                layer.weight = self.shared_weights
            # Every layer, including the first, gets its own bias from the list
            # (the original left the first layer's bias untied)
            layer.bias = self.shared_biases[i]
            self.shared_layers.append(layer)

        # Unique layers (later stages) have their own weights and biases
        self.unique_layers = nn.ModuleList()
        in_channels = self.config.hidden_size
        for _ in range(self.config.num_shared_layers):
            self.unique_layers.append(
                nn.Conv2d(
                    in_channels,
                    self.config.hidden_size,
                    kernel_size=self.config.kernel_size,
                    padding=1
                )
            )

        # Activation and pooling layers
        self.activation_function = getattr(nn, self.config.activation_function)()
        self.max_pool = nn.MaxPool2d(self.config.kernel_size)

    def forward(self, x):
        for layer in self.shared_layers:
            x = layer(x)
            x = self.activation_function(x)
            x = self.max_pool(x)

        for layer in self.unique_layers:
            x = layer(x)
            x = self.activation_function(x)
            x = self.max_pool(x)

        return x
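A quick check of the weight tying, sketched under the assumption num_channels == hidden_size so that a second shared layer is shape-compatible (sizes illustrative):

    config = AIIAConfig(num_channels=8, hidden_size=8, kernel_size=3)
    model = AIIABaseShared(config, num_shared_layers=2)

    # Both shared layers point at the same weight tensor, but have distinct biases
    assert model.shared_layers[0].weight is model.shared_layers[1].weight
    assert model.shared_layers[0].bias is not model.shared_layers[1].bias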
class AIIABase(AIIA):
    def __init__(self, config: AIIAConfig, **kwargs):
        super().__init__(config=config, **kwargs)

        # Initialize layers based on configuration
        layers = []
        in_channels = self.config.num_channels

        for _ in range(self.config.num_hidden_layers):
            layers.extend([
                nn.Conv2d(in_channels, self.config.hidden_size,
                          kernel_size=self.config.kernel_size, padding=1),
                getattr(nn, self.config.activation_function)(),
                nn.MaxPool2d(kernel_size=2)
            ])
            in_channels = self.config.hidden_size

        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.cnn(x)
class AIIAExpert(AIIA):
    def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)

        # Initialize base CNN with configuration and chosen base class
        if issubclass(base_class, AIIABase):
            self.base_cnn = AIIABase(self.config, **kwargs)
        elif issubclass(base_class, AIIABaseShared):
            self.base_cnn = AIIABaseShared(self.config, **kwargs)
        else:
            raise ValueError("Invalid base class")

    def forward(self, x):
        # Without this, nn.Module would raise when AIIAmoe calls expert(x)
        return self.base_cnn(x)


class AIIAmoe(AIIA):
    def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)
        self.config.num_experts = num_experts

        # Initialize multiple experts using the chosen base class
        self.experts = nn.ModuleList([
            AIIAExpert(self.config, base_class=base_class, **kwargs)
            for _ in range(self.config.num_experts)
        ])

        # Gating network: a per-sample softmax over the experts
        self.gate = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.num_experts),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        # expert_outputs: (batch, num_experts, hidden_size, H, W)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        # Pool the expert and spatial dims down to (batch, hidden_size) so the
        # gate's Linear layer sees the feature dimension it was built for
        gate_weights = self.gate(expert_outputs.mean(dim=(1, 3, 4)))
        # Broadcast the (batch, num_experts) weights over channels and space,
        # then sum out the expert dimension
        merged_output = torch.sum(
            expert_outputs * gate_weights[:, :, None, None, None], dim=1
        )
        return merged_output
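A shape walkthrough of the gating path, as a sketch (sizes chosen so the arithmetic stays clean; kernel_size=3 keeps the convs shape-preserving):

    import torch
    from aiia.model import AIIAmoe
    from aiia.model.config import AIIAConfig

    config = AIIAConfig(hidden_size=32, num_hidden_layers=2, num_channels=3, kernel_size=3)
    moe = AIIAmoe(config, num_experts=3)

    x = torch.randn(4, 3, 64, 64)     # (batch, channels, H, W)
    out = moe(x)
    # Each expert halves H and W twice (two MaxPool2d(2) stages), so:
    # expert_outputs: (4, 3, 32, 16, 16) -> merged: (4, 32, 16, 16)
    print(out.shape)                  # torch.Size([4, 32, 16, 16])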
class AIIAchunked(AIIA):
    def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)
        self.config.patch_size = patch_size

        # Initialize base CNN for processing each patch using the specified base class
        if issubclass(base_class, AIIABase):
            self.base_cnn = AIIABase(self.config, **kwargs)
        elif issubclass(base_class, AIIABaseShared):
            self.base_cnn = AIIABaseShared(self.config, **kwargs)
        else:
            raise ValueError("Invalid base class")

    def forward(self, x):
        ps = self.config.patch_size  # only the config holds patch_size; self.patch_size was never set
        # Cut the image into non-overlapping ps x ps patches:
        # (B, C, H, W) -> (B, C, H/ps, W/ps, ps, ps) -> (B, C, n_patches, ps, ps)
        patches = x.unfold(2, ps, ps).unfold(3, ps, ps)
        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, ps, ps)
        patch_outputs = []

        for p in torch.split(patches, 1, dim=2):
            p = p.squeeze(2)
            patch_outputs.append(self.base_cnn(p))

        # Average the per-patch feature maps
        combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
        return combined_output
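The unfold bookkeeping, isolated as a sketch (a 64x64 input with 16x16 patches yields 4x4 = 16 patches):

    import torch

    x = torch.randn(2, 3, 64, 64)
    patches = x.unfold(2, 16, 16).unfold(3, 16, 16)    # (2, 3, 4, 4, 16, 16)
    patches = patches.contiguous().view(2, 3, -1, 16, 16)
    print(patches.shape)   # torch.Size([2, 3, 16, 16, 16]) -> 16 patches of 16x16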
class AIIArecursive(AIIA):
    def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)
        self.config.recursion_depth = recursion_depth

        # Initialize chunked CNN with the updated config. base_class must be
        # passed by keyword: positionally it would land in patch_size.
        self.chunked_cnn = AIIAchunked(self.config, base_class=base_class, **kwargs)

    def forward(self, x, depth=0):
        if depth == self.config.recursion_depth:
            return self.chunked_cnn(x)

        # Split into 16x16 patches and recurse into each patch
        patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
        processed_patches = []

        for p in torch.split(patches, 1, dim=2):
            p = p.squeeze(2)
            processed_patches.append(self.forward(p, depth + 1))

        combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
        return combined_output


if __name__ == "__main__":
    # Smoke test, guarded so it no longer runs on import:
    # build and save a shared-weights model
    config = AIIAConfig()
    model2 = AIIABaseShared(config)
    model2.save("shared")
aiia/model/__init__.py (filename inferred)
@@ -0,0 +1,2 @@
from .config import AIIAConfig
from .Model import AIIA, AIIABase, AIIABaseShared, AIIAchunked, AIIAExpert, AIIAmoe, AIIArecursive
aiia/model/config.py (filename inferred from the imports elsewhere in this PR)
@@ -0,0 +1,53 @@
import json
import os

import torch.nn as nn


class AIIAConfig:
    def __init__(
        self,
        model_name: str = "AIIA",
        kernel_size: int = 5,
        activation_function: str = "GELU",
        hidden_size: int = 512,
        num_hidden_layers: int = 12,
        num_channels: int = 3,
        learning_rate: float = 5e-5,
        **kwargs
    ):
        self.model_name = model_name
        self.kernel_size = kernel_size
        self.activation_function = activation_function
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_channels = num_channels
        self.learning_rate = learning_rate

        # Store additional keyword arguments as attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

    @property
    def activation_function(self):
        return self._activation_function

    @activation_function.setter
    def activation_function(self, value):
        # Accept only names that resolve to something callable in torch.nn
        attr = getattr(nn, value, None)
        if attr is None or (not callable(attr) and not isinstance(attr, type(nn.Module))):
            valid_funcs = [func for func in dir(nn)
                           if callable(getattr(nn, func)) or isinstance(getattr(nn, func), type(nn.Module))]
            raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}")
        self._activation_function = value

    def save(self, file_path):
        os.makedirs(file_path, exist_ok=True)
        # Serialize under the public name so load() can pass it straight back
        # as a constructor kwarg (vars() would leak '_activation_function')
        config_dict = vars(self).copy()
        config_dict["activation_function"] = config_dict.pop("_activation_function")
        with open(f"{file_path}/config.json", "w") as f:
            json.dump(config_dict, f, indent=4)

    @classmethod
    def load(cls, file_path):
        with open(f"{file_path}/config.json", "r") as f:
            config_dict = json.load(f)
        return cls(**config_dict)
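A round trip plus the validation path, sketched with an illustrative directory name:

    config = AIIAConfig(model_name="demo", hidden_size=128)
    config.save("demo-config")                 # writes demo-config/config.json
    restored = AIIAConfig.load("demo-config")
    assert restored.activation_function == "GELU"

    try:
        AIIAConfig(activation_function="NotALayer")
    except ValueError as e:
        print(e)  # lists valid torch.nn names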
Pretraining script (filename not shown in this diff view)
@@ -0,0 +1,149 @@
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader

from aiia.model.config import AIIAConfig
from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader


def pretrain_model(data_path1, data_path2, num_epochs=3):
    # Merge the two parquet files
    df1 = pd.read_parquet(data_path1)
    df2 = pd.read_parquet(data_path2)
    merged_df = pd.concat([df1, df2], ignore_index=True)

    # Create a new AIIAConfig instance.
    # Note: each hidden layer halves H and W via MaxPool2d(2), so 12 layers
    # need very large inputs to avoid collapsing the feature map.
    config = AIIAConfig(
        model_name="AIIA-512x",
        hidden_size=512,
        num_hidden_layers=12,
        kernel_size=5,
        learning_rate=5e-5
    )

    # Initialize the base model
    model = AIIABase(config)

    # Create dataset loader with merged data
    train_dataset = AIIADataLoader(
        merged_df,
        batch_size=32,
        val_split=0.2,
        seed=42,
        column="file_path",
        label_column=None
    )

    # Separate dataloaders for the training and validation sets. The identity
    # collate_fn keeps each batch as a list of (input, target, task) tuples,
    # since images may have different sizes and default collation would fail.
    train_dataloader = DataLoader(
        train_dataset.train_dataset,
        batch_size=train_dataset.batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=lambda batch: batch
    )

    val_dataloader = DataLoader(
        train_dataset.val_dataset,  # fixed typo: was val_ataset
        batch_size=train_dataset.batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=lambda batch: batch
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # AIIABase outputs a (hidden_size, h, w) feature map, not an image or
    # class logits, so small task heads are assumed here (the PR defines
    # none): a 1x1 conv for denoising and a pooled linear layer for the
    # 4-way rotation task.
    denoise_head = nn.Conv2d(config.hidden_size, config.num_channels, kernel_size=1).to(device)
    rotate_head = nn.Linear(config.hidden_size, 4).to(device)

    # Initialize loss functions and optimizer (task heads included)
    criterion_denoise = nn.MSELoss()
    criterion_rotate = nn.CrossEntropyLoss()
    params = (list(model.parameters())
              + list(denoise_head.parameters())
              + list(rotate_head.parameters()))
    optimizer = torch.optim.AdamW(params, lr=config.learning_rate)

    best_val_loss = float('inf')

    def compute_loss(image, target, task):
        # Process each sample individually since tasks and image sizes vary
        features = model(image.unsqueeze(0).to(device))
        if task == 'denoise':
            # Upsample the head's output back to the input resolution
            pred = nn.functional.interpolate(
                denoise_head(features), size=image.shape[-2:]
            )
            return criterion_denoise(pred.squeeze(0), target.to(device))
        elif task == 'rotate':
            logits = rotate_head(features.mean(dim=(2, 3)))
            return criterion_rotate(logits, target.to(device).unsqueeze(0))
        raise ValueError(f"Unknown task: {task}")

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 20)

        # Training phase
        model.train()
        total_train_loss = 0.0

        for batch in train_dataloader:
            optimizer.zero_grad()

            total_loss = 0.0
            for image, target, task in batch:
                # Per-task losses could also be tracked here for reporting;
                # see the sketch after this file.
                total_loss = total_loss + compute_loss(image, target, task)

            avg_loss = total_loss / len(batch)
            avg_loss.backward()
            optimizer.step()

            total_train_loss += avg_loss.item()

        avg_total_train_loss = total_train_loss / len(train_dataloader)
        print(f"Training Loss: {avg_total_train_loss:.4f}")

        # Validation phase
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_dataloader:
                total_loss = 0.0
                for image, target, task in batch:
                    total_loss = total_loss + compute_loss(image, target, task)
                val_losses.append((total_loss / len(batch)).item())

        avg_val_loss = sum(val_losses) / len(val_losses)
        print(f"Validation Loss: {avg_val_loss:.4f}")

        # Save the best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save("BASEv0.1")
            print("Best model saved!")


if __name__ == "__main__":
    data_path1 = "/root/training_data/vision-dataset/images_dataset.parquet"
    data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
    pretrain_model(data_path1, data_path2, num_epochs=8)
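If per-task loss reporting is wanted (the original code left this as a TODO), a dict keyed by task name is a minimal sketch:

    from collections import defaultdict

    task_losses = defaultdict(list)
    # inside the per-sample loop:
    #     loss = compute_loss(image, target, task)
    #     task_losses[task].append(loss.item())

    for task, losses in task_losses.items():
        print(f"{task}: {sum(losses) / max(len(losses), 1):.4f}")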