corrected the base model_name
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Failing after 29s
Details
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Failing after 29s
Details
This commit is contained in:
parent
9aa66cbe33
commit
a7bab86d85
|
@ -17,24 +17,24 @@ class aiuNN(PreTrainedModel):
|
||||||
|
|
||||||
# Enhanced approach
|
# Enhanced approach
|
||||||
scale_factor = self.config.upsample_scale
|
scale_factor = self.config.upsample_scale
|
||||||
out_channels = self.base_model.config.num_channels * (scale_factor ** 2)
|
out_channels = self.aiia_model.config.num_channels * (scale_factor ** 2)
|
||||||
self.pixel_shuffle_conv = nn.Conv2d(
|
self.pixel_shuffle_conv = nn.Conv2d(
|
||||||
in_channels=self.base_model.config.hidden_size,
|
in_channels=self.aiia_model.config.hidden_size,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
kernel_size=self.base_model.config.kernel_size,
|
kernel_size=self.aiia_model.config.kernel_size,
|
||||||
padding=1
|
padding=1
|
||||||
)
|
)
|
||||||
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
||||||
|
|
||||||
def load_base_model(self, base_model: PreTrainedModel):
|
def load_base_model(self, base_model: PreTrainedModel):
|
||||||
self.base_model = base_model
|
self.aiia_model = base_model
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.base_model is None:
|
if self.aiia_model is None:
|
||||||
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
|
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
|
||||||
|
|
||||||
# Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict
|
# Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict
|
||||||
base_output = self.base_model(x)
|
base_output = self.aiia_model(x)
|
||||||
if isinstance(base_output, tuple):
|
if isinstance(base_output, tuple):
|
||||||
x = base_output[0]
|
x = base_output[0]
|
||||||
elif isinstance(base_output, dict):
|
elif isinstance(base_output, dict):
|
||||||
|
|
Loading…
Reference in New Issue