Merge pull request 'corrected MOTrainer' (#27) from feat/change_mef into main
Reviewed-on: #27
commit dc12ec50ff
pyproject.toml
@@ -3,7 +3,7 @@ requires = ["setuptools>=45", "wheel"]
 build-backend = "setuptools.build_meta"
 [project]
 name = "aiunn"
-version = "0.4.0"
+version = "0.4.1"
 description = "Finetuner for image upscaling using AIIA"
 readme = "README.md"
 requires-python = ">=3.10"
setup.py
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
 
 setup(
     name="aiunn",
-    version="0.4.0",
+    version="0.4.1",
     packages=find_packages(where="src"),
     package_dir={"": "src"},
     install_requires=[
src/aiunn/__init__.py
@@ -4,4 +4,4 @@ from .upsampler.aiunn import aiuNN
 from .upsampler.config import aiuNNConfig
 from .inference.inference import aiuNNInference
 
-__version__ = "0.4.0"
+__version__ = "0.4.1"
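With the version bumped in all three places, the attribute exposed by the package and the installed distribution metadata should agree. A quick sanity check, assuming the package has been reinstalled after the bump (a hypothetical snippet, not part of the repository):

# Hypothetical check that the bumped version is consistent between the
# package attribute and the installed distribution metadata.
from importlib.metadata import version

import aiunn

assert aiunn.__version__ == "0.4.1"
assert version("aiunn") == "0.4.1"
print("aiunn version:", aiunn.__version__)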
MemoryOptimizedTrainer (trainer module)
@@ -177,6 +177,7 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
 
             with autocast(device_type=self.device.type):
                 outputs = self.model(low_res)
+                outputs = outputs.clone()
                 loss = self.criterion(outputs, high_res)
 
             val_loss += loss.item()
@@ -252,10 +253,10 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
             if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
                 low_res.requires_grad_()
                 outputs = checkpoint(self.model, low_res)
-                outputs = outputs.clone() # <-- Clone added here
+                outputs = outputs.clone()
             else:
                 outputs = self.model(low_res)
-                outputs = outputs.clone() # <-- Clone added here
+                outputs = outputs.clone()
             loss = self.criterion(outputs, high_res)
 
             # Scale loss for gradient accumulation
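Both hunks apply the same pattern: run the forward pass under autocast (directly, or through torch.utils.checkpoint when activation checkpointing is enabled), copy the output with .clone(), and only then compute the loss. The snippet below is a minimal, self-contained sketch of that pattern, not the actual aiuNN trainer: the toy convolution, tensor shapes, and MSE loss are placeholders, and reading the clone as a guard against later in-place modification of the forward output is an assumption the diff itself does not state.

# Minimal sketch of the forward pattern from both hunks; the model, shapes,
# and loss are stand-ins, not the real aiuNN classes.
import torch
import torch.nn as nn
from torch import autocast
from torch.utils.checkpoint import checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Conv2d(3, 3, kernel_size=3, padding=1).to(device)  # stand-in upscaler
criterion = nn.MSELoss()                                       # stand-in loss

low_res = torch.randn(2, 3, 32, 32, device=device)
high_res = torch.randn(2, 3, 32, 32, device=device)
use_checkpointing = True

with autocast(device_type=device.type):
    if use_checkpointing:
        # checkpoint() re-runs this segment during backward instead of storing
        # activations; the input must require grad for gradients to flow.
        low_res.requires_grad_()
        outputs = checkpoint(model, low_res)
        outputs = outputs.clone()  # copy the output before the loss, as in the PR
    else:
        outputs = model(low_res)
        outputs = outputs.clone()
    loss = criterion(outputs, high_res)

loss.backward()
print("loss:", loss.item())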