finetune_class #1
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from aiia import AIIA, AIIAConfig, AIIABase
|
from aiia import AIIA, AIIAConfig, AIIABase
|
||||||
|
from config import UpsamplerConfig
|
||||||
|
|
||||||
# Upsampler model that uses the configuration from the base model.
|
# Upsampler model that uses the configuration from the base model.
|
||||||
class Upsampler(AIIA):
|
class Upsampler(AIIA):
|
||||||
|
@ -9,7 +9,7 @@ class Upsampler(AIIA):
|
||||||
# Assume that base_model.config is an instance of UpsamplerConfig.
|
# Assume that base_model.config is an instance of UpsamplerConfig.
|
||||||
super().__init__(base_model.config)
|
super().__init__(base_model.config)
|
||||||
self.base_model = base_model
|
self.base_model = base_model
|
||||||
|
self.config = UpsamplerConfig(self.base_model.config)
|
||||||
# Create the upsample layer using values from the configuration.
|
# Create the upsample layer using values from the configuration.
|
||||||
self.upsample = nn.Upsample(
|
self.upsample = nn.Upsample(
|
||||||
scale_factor=self.config.upsample_scale,
|
scale_factor=self.config.upsample_scale,
|
||||||
|
|
Loading…
Reference in New Issue