/*
* Copyright (c) 2019, Alliance for Open Media. All rights reserved.
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
#include <assert.h>
#include <math.h>
#include <stdio.h>

#include <vector>

#include "gtest/gtest.h"

#include "config/av1_rtcd.h"

#include "aom_ports/aom_timer.h"
#include "av1/encoder/cnn.h"
#include "av1/encoder/partition_cnn_weights.h"
#include "test/acm_random.h"
#include "test/function_equivalence_test.h"
#include "test/util.h"
// Square of an arithmetic expression; the argument is evaluated twice, so it
// must be free of side effects.
#define SQR(x) ((x) * (x))
// Best possible pixelwise guaranteed precision given each float has at most
// 3 specified decimals.
#define PIXELWISE_FLOAT_TOL 1E-2
// Mean-squared-error tolerance for outputs expected to be close but not exact.
#define MSE_FLOAT_TOL 1E-6
// Mean-squared-error tolerance for outputs expected to match exactly.
#define MSE_INT_TOL 0
// CNN convolve pixelwise error threshold for functional equivalence.
// NOTE: floating literals must not contain whitespace; "1 E-3 f" would expand
// into a syntax error at every use site.
#define CNN_CONVOLVE_PIXELWISE_FLOAT_TOL 1E-3f
namespace {
class CNNTest : public ::testing::Test {
protected :
static void RunCNNTest(int image_width, int image_height, const float *input,
const float *expected, const CNN_CONFIG *cnn_config,
int in_stride, CNN_THREAD_DATA *thread_data,
double tolerance) {
int out_width, out_height, out_channels;
av1_find_cnn_output_size(image_width, image_height, cnn_config, &out_width,
&out_height, &out_channels);
const int out_size = out_width * out_height;
const int out_stride = out_width;
float *output_ =
(float *)aom_malloc(sizeof (*output_) * out_size * out_channels);
ASSERT_NE(output_, nullptr);
float *output[CNN_MAX_CHANNELS] = { nullptr };
for (int channel = 0 ; channel < out_channels; ++channel) {
output[channel] = output_ + (channel * out_size);
}
const int num_outputs = 1 ;
const int output_chs[1 ] = { out_channels };
const int output_strides[1 ] = { out_stride };
CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_strides,
output };
RunMultiOutCNNTest(&input, image_width, image_height, in_stride, cnn_config,
thread_data, &output_struct, &expected, tolerance);
aom_free(output_);
}
static void RunMultiOutCNNTest(const float **input, int image_width,
int image_height, int in_stride,
const CNN_CONFIG *cnn_config,
CNN_THREAD_DATA *thread_data,
CNN_MULTI_OUT *output, const float **expected,
double tolerance) {
const int num_outputs = output->num_outputs;
const int *output_chs = output->output_channels;
int *out_widths = (int *)aom_calloc(num_outputs, sizeof (*out_widths));
int *out_heights = (int *)aom_calloc(num_outputs, sizeof (*out_heights));
int *not_used = (int *)aom_calloc(num_outputs, sizeof (*not_used));
ASSERT_NE(out_widths, nullptr);
ASSERT_NE(out_heights, nullptr);
ASSERT_NE(not_used, nullptr);
av1_find_cnn_output_size(image_width, image_height, cnn_config, out_widths,
out_heights, not_used);
ASSERT_TRUE(av1_cnn_predict(input, image_width, image_height, in_stride,
cnn_config, thread_data, output));
int channel_offset = 0 ;
for (int output_idx = 0 ; output_idx < num_outputs; output_idx++) {
const float *expected_out = expected[output_idx];
const int curr_output_chs = output_chs[output_idx];
const int out_size = out_widths[output_idx] * out_heights[output_idx];
double mse = 0 ;
int expected_ite = 0 ;
for (int channel = 0 ; channel < curr_output_chs; ++channel) {
const float *buf_out = output->output_buffer[channel_offset];
for (int i = 0 ; i < out_size; ++i) {
EXPECT_NEAR(expected_out[expected_ite], buf_out[i],
PIXELWISE_FLOAT_TOL)
<< " output " << output_idx << " channel " << channel << " pixel "
<< expected_ite % out_size << ": " << expected_out[expected_ite]
<< "/" << buf_out[i] << std::endl;
mse += SQR(expected_out[expected_ite] - buf_out[i]);
expected_ite++;
}
channel_offset++;
}
mse /= (out_size * curr_output_chs);
EXPECT_LE(mse, tolerance) << " output " << output_idx << std::endl;
}
aom_free(out_widths);
aom_free(out_heights);
aom_free(not_used);
}
static void AssignLayerWeightsBiases(CNN_CONFIG *cnn_config, float *weights,
float *bias) {
size_t weight_offset = 0 ;
size_t bias_offset = 0 ;
for (int layer = 0 ; layer < cnn_config->num_layers; ++layer) {
CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
layer_config->weights = weights + weight_offset;
layer_config->bias = bias + bias_offset;
weight_offset += layer_config->filter_width *
layer_config->filter_height * layer_config->in_channels *
layer_config->out_channels;
bias_offset += layer_config->out_channels;
ASSERT_NE(layer_config->weights, nullptr);
ASSERT_NE(layer_config->bias, nullptr);
}
}
};
} // namespace
// Stacks three 4x5 convolutions (1 -> 3 -> 3 -> 1 channels) over a 16x16
// single-channel image. The final output is checked under each padding mode.
// All runs share one packed weight/bias set; between runs only the padding of
// every layer is switched.
TEST_F(CNNTest, TestMultilayerConvolution) {
  const int image_height = 16;
  const int image_width = 16;
  const int filter_height = 5;
  const int filter_width = 4;
  float input[] = {
      -3 , 1 , -3 , 2 , -2 , -2 , 2 , -2 , 1 , -2 , -3 , 1 , 2 , 2 , 2 , -2 , 0 , 1 , -1 ,
      -3 , -1 , -1 , 1 , 0 , -3 , 1 , 0 , -1 , 1 , 0 , 0 , -3 , -3 , -3 , 0 , 2 , 1 , -1 ,
      2 , 0 , 1 , -3 , -1 , 2 , 2 , 1 , -2 , 0 , -1 , 0 , -2 , -2 , -1 , 1 , 0 , 0 , 0 ,
      -2 , -2 , -2 , 1 , 1 , -2 , 1 , 1 , -2 , -2 , 1 , -2 , -1 , -2 , -3 , 2 , -3 , -1 , 1 ,
      0 , -2 , -2 , -2 , 1 , -2 , -2 , -1 , -1 , 2 , 2 , 2 , -1 , 1 , -3 , -3 , 0 , 2 , 0 ,
      2 , 1 , -3 , -3 , 1 , 2 , 2 , 1 , -2 , -3 , 0 , -3 , 0 , -3 , -2 , 0 , 1 , 1 , 0 ,
      -3 , 2 , -1 , 2 , 1 , 0 , 1 , -2 , 1 , -1 , -1 , 2 , 0 , -2 , -3 , 1 , 1 , -2 , -1 ,
      -3 , -3 , -1 , 0 , -3 , -2 , 0 , 0 , 1 , 0 , -3 , -2 , -1 , 1 , 0 , 2 , 1 , 0 , -3 ,
      -2 , -3 , -3 , -1 , 0 , -2 , 2 , -1 , -3 , 0 , -1 , -1 , 2 , 0 , -3 , -2 , -1 , 0 , 0 ,
      1 , -2 , 1 , 2 , 1 , 2 , 2 , -3 , 2 , -1 , 0 , 0 , -1 , 0 , 2 , 2 , -1 , 2 , -2 ,
      1 , 1 , -3 , -3 , 1 , -1 , -1 , -2 , 2 , -2 , -2 , 2 , -1 , -3 , 2 , -3 , 1 , -1 , -1 ,
      -3 , 1 , -1 , 1 , 0 , -3 , -3 , 1 , -3 , -3 , 0 , 2 , 2 , -2 , -1 , 2 , 0 , 2 , 1 ,
      -1 , -3 , 0 , 0 , -1 , -1 , 1 , 0 , 2 , 0 , -3 , 2 , 1 , 0 , 1 , -3 , 2 , -3 , -3 ,
      -1 , -3 , -3 , 2 , 0 , 2 , -2 , 1 , -1 ,
  };
  // Weights for all layers packed back to back: layer 0, then 1, then 2.
  float weights[] = {
      -2 , 2 , -2 , 2 , -1 , -3 , 2 , 2 , 0 , 0 , -3 , -1 , -2 , -3 , 1 , -1 , 0 , 0 , 0 ,
      2 , -2 , 2 , -2 , -3 , 1 , 1 , 1 , -3 , -1 , 0 , 1 , 2 , -2 , 0 , -1 , -3 , -1 , -2 ,
      2 , -3 , -3 , 1 , -2 , -3 , 0 , 2 , 1 , -3 , -3 , -1 , -3 , -2 , -1 , -3 , -1 , -3 , -2 ,
      -1 , -3 , -1 , -2 , -2 , -3 , 2 , 0 , -3 , 0 , -3 , -3 , 1 , -3 , -1 , 0 , -1 , 1 , 1 ,
      -1 , 1 , -2 , 0 , 2 , 0 , -3 , 1 , -1 , -1 , 2 , 0 , 1 , -3 , -3 , 1 , 2 , -3 , -3 ,
      1 , -3 , 2 , 0 , -3 , 1 , 2 , 2 , -2 , -1 , -2 , 1 , 1 , 0 , -2 , -2 , 1 , 2 , -1 ,
      -3 , 1 , -2 , 2 , -3 , -2 , -3 , 2 , 1 , 0 , -2 , 0 , 1 , -3 , 2 , -2 , -2 , 0 , 2 ,
      -3 , 2 , 0 , 0 , 1 , -2 , 1 , 1 , -2 , -1 , -2 , 1 , -2 , 0 , -2 , -2 , 0 , -1 , -1 ,
      -3 , -3 , -3 , 1 , -3 , -2 , 2 , -1 , 2 , 0 , 2 , -2 , 2 , -2 , 1 , -3 , -3 , -1 , 0 ,
      2 , 2 , 1 , -1 , -3 , -1 , -3 , 2 , 1 , -2 , 0 , -3 , -1 , -3 , -1 , 2 , 1 , 0 , 2 ,
      -1 , 1 , 0 , 1 , 2 , -1 , -2 , 2 , 1 , -3 , -1 , -3 , 0 , 1 , -2 , 0 , -2 , -3 , 0 ,
      -2 , 2 , 2 , 0 , 0 , 2 , -3 , 2 , -3 , -2 , 1 , 2 , -3 , -3 , -1 , -3 , 0 , -3 , -3 ,
      -2 , -2 , -2 , 0 , 0 , 1 , 0 , 0 , -1 , 0 , 0 , -3 , 0 , -3 , -1 , -2 , 1 , -2 , -1 ,
      2 , -2 , 0 , 0 , 1 , 0 , -2 , -1 , 0 , -3 , 1 , 0 , -1 , -3 , 1 , -1 , 1 , -1 , -3 ,
      1 , 0 , 1 , 1 , -1 , 2 , 2 , 0 , 0 , 1 , -3 , 2 , -2 , -2 , -3 , -2 , -1 , -2 , 2 ,
      0 , 2 , -2 , -3 , -1 , -3 , 2 , 2 , -1 , 2 , 2 , -1 , 0 , -3 , 1 ,
  };
  // Biases packed the same way: 3 + 3 + 1 values.
  float bias[] = {
      1 , -1 , 0 , 1 , 1 , 1 , -2 ,
  };
  float expected_same[] = {
      -1125 , 2926 , 6406 , 631 , -1244 , 97 , -1454 , 2526 , 1065 , 3292 , 3464 ,
      2553 , -330 , 532 , 1038 , 1182 , -402 , 3758 , 3392 , 9854 , 4365 , 1408 ,
      4736 , 3134 , 3838 , 2409 , 3221 , 4350 , 6750 , 4045 , 815 , 1188 , 2959 ,
      9802 , 9590 , 4572 , 5740 , 4253 , 1701 , 7974 , 7012 , 6854 , 7093 , 3907 ,
      4539 , 3886 , 4267 , 3505 , 465 , 7824 , 9219 , 10026 , 7968 , 957 , 2295 ,
      5594 , 10811 , 9641 , 5950 , 10043 , 8783 , 3132 , 1421 , 1110 , 4108 , 13929 ,
      10660 , -84 , -61 , 3932 , -180 , 6811 , 13393 , 15147 , 15640 , 9337 , 6961 ,
      3808 , 1604 , 1398 , 1047 , 6739 , 10144 , 6517 , 4698 , 2678 , 7389 , 2595 ,
      5248 , 12075 , 11272 , 13951 , 8820 , 1090 , 2199 , 2206 , 2788 , 12116 , 6683 ,
      2612 , -291 , 3183 , 9414 , 12316 , 14524 , 12333 , 13208 , 7832 , 4664 , 4657 ,
      3534 , 1298 , -666 , 4250 , 7707 , 9103 , 5760 , 688 , 9571 , 15782 , 14203 ,
      14878 , 17339 , 14684 , 8690 , 5671 , 875 , 1429 , 1531 , 6173 , 2984 , 5558 ,
      2996 , 7928 , 6733 , 16117 , 15262 , 12757 , 7980 , 3923 , 4795 , 5973 , 2051 ,
      455 , -1922 , 1816 , 5906 , 3321 , 10908 , 10910 , 7377 , 12204 , 12809 , 11195 ,
      7451 , 6666 , 74 , -1645 , -35 , -391 , 3813 , 7324 , 892 , 1656 , 6095 ,
      12193 , 14648 , 12156 , 14663 , 10251 , 10325 , 7821 , 3925 , 323 , 697 , 442 ,
      1324 , 4669 , 7002 , 5485 , 5171 , 5086 , 10582 , 11053 , 9709 , 11353 , 8543 ,
      5256 , 2873 , 235 , -628 , 1496 , 1878 , -867 , 3420 , 6865 , 5937 , 10182 ,
      13277 , 10069 , 10789 , 5998 , 624 , -2082 , 4417 , 1258 , -1080 , -819 , -1430 ,
      1033 , 5220 , 6335 , 8471 , 8980 , 11908 , 14430 , 12584 , 8404 , 1576 , -803 ,
      985 , 1481 , 1367 , -193 , 873 , 3684 , 2288 , 6676 , 9477 , 11155 , 9602 ,
      9707 , 10507 , 4739 , 3174 , -575 , -178 , 3002 , 1710 , 423 , -477 , 554 ,
      3088 , 2029 , 5113 , 5000 , 3771 , 6090 , 5365 , 1185 , 2855 , 399 , -312 ,
      -1577 , 176 , 955 ,
  };
  float expected_replicate[] = {
      13768 , 13528 , 12999 , 6906 , 4618 , 4043 , 2611 , 9955 , 6685 , 4776 , 2753 ,
      1036 , 3063 , 4544 , 5183 , 7349 , 12451 , 12501 , 9131 , 12753 , 8908 , 4058 ,
      6299 , 7542 , 7115 , 3307 , 3360 , 3543 , 9754 , 7808 , 5991 , 9019 , 14320 ,
      14919 , 12492 , 6871 , 7373 , 3336 , 2085 , 10604 , 9377 , 6882 , 5009 , 3103 ,
      6220 , 6278 , 7588 , 10196 , 11045 , 11563 , 11842 , 11911 , 8279 , 2030 , 1858 ,
      6368 , 12123 , 9909 , 6347 , 10345 , 9365 , 4038 , 1673 , 3051 , 16492 , 16649 ,
      12276 , 408 , -301 , 4122 , -654 , 7864 , 14038 , 15279 , 15315 , 9744 , 8243 ,
      5298 , 746 , 380 , 9824 , 9124 , 10895 , 6640 , 4712 , 2669 , 6980 , 2759 ,
      5385 , 12345 , 11336 , 13129 , 8600 , 2370 , 3682 , 5219 , 12407 , 13123 , 6784 ,
      2612 , -291 , 3183 , 9414 , 12316 , 14524 , 12333 , 13397 , 7543 , 3916 , 4153 ,
      4477 , 4314 , 7983 , 8418 , 9163 , 9103 , 5760 , 688 , 9571 , 15782 , 14203 ,
      14878 , 17718 , 14570 , 7940 , 6642 , 5094 , 7133 , 9964 , 10219 , 3224 , 5558 ,
      2996 , 7928 , 6733 , 16117 , 15262 , 12757 , 7958 , 4401 , 5187 , 5476 , 5529 ,
      6055 , 2206 , 3909 , 6015 , 3321 , 10908 , 10910 , 7377 , 12204 , 12809 , 11195 ,
      6967 , 6840 , 481 , -1600 , 274 , 1 , 10373 , 8514 , 1123 , 2117 , 6758 ,
      12736 , 16223 , 13585 , 15988 , 11771 , 10600 , 7918 , 4156 , 2840 , 3111 , 3287 ,
      6359 , 7652 , 8813 , 6530 , 6967 , 7789 , 13671 , 13990 , 13247 , 13241 , 9836 ,
      5251 , 3024 , 2313 , 1834 , 4187 , 2637 , -1312 , 2139 , 7378 , 7665 , 11933 ,
      15591 , 15314 , 15678 , 9531 , 2820 , -1516 , 3400 , 1314 , 22 , 363 , -2896 ,
      -898 , 5906 , 7308 , 10650 , 12975 , 16978 , 20370 , 18817 , 12381 , 4118 , -861 ,
      -137 , 236 , 1802 , 1632 , -350 , 2334 , 3400 , 8680 , 14064 , 18216 , 18675 ,
      21765 , 22871 , 11491 , 4937 , -1555 , -11 , 1669 , 2392 , 3265 , -5254 , -217 ,
      5001 , 8063 , 13444 , 18884 , 19706 , 22794 , 21064 , 9545 , 6689 , -7 , 289 ,
      -2021 , 504 , 2347 ,
  };
  float expected_valid[] = {
      2612 , -291 , 3183 , 9414 , 12316 , 14524 , 12333 , 9103 , 5760 , 688 ,
      9571 , 15782 , 14203 , 14878 , 5558 , 2996 , 7928 , 6733 , 16117 , 15262 ,
      12757 , 3321 , 10908 , 10910 , 7377 , 12204 , 12809 , 11195 ,
  };
  // Layer fields are positional. Presumed order, per CNN_LAYER_CONFIG (confirm
  // against cnn.h if it changes): in_channels, filter_width, filter_height,
  // out_channels, skip_width, skip_height, maxpool, weights, bias, pad,
  // activation, deconvolve, branch, branch copy/combine types, branch config,
  // batch-norm params, output_num. Weights and bias stay null here; they are
  // pointed into the packed arrays by AssignLayerWeightsBiases() below.
  CNN_CONFIG cnn_config = { 3,
                            0,
                            0,
                            0,
                            0,
                            {
                                {
                                    1, filter_width, filter_height, 3, 1, 1, 0,
                                    nullptr, nullptr, PADDING_SAME_ZERO, NONE,
                                    0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
                                    -1,
                                },
                                {
                                    3, filter_width, filter_height, 3, 1, 1, 0,
                                    nullptr, nullptr, PADDING_SAME_ZERO, NONE,
                                    0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
                                    -1,
                                },
                                {
                                    3, filter_width, filter_height, 1, 1, 1, 0,
                                    nullptr, nullptr, PADDING_SAME_ZERO, NONE,
                                    0, 0, BRANCH_NO_COPY, BRANCH_NOC, {}, {},
                                    0,
                                },
                            } };
  AssignLayerWeightsBiases(&cnn_config, weights, bias);
  CNN_THREAD_DATA thread_data = { 1, nullptr };

  // One run per padding mode, applied uniformly to every layer.
  using PaddingType = decltype(cnn_config.layer_config[0].pad);
  const struct {
    PaddingType pad;
    const float *expected;
  } kPaddingCases[] = {
    { PADDING_SAME_ZERO, expected_same },
    { PADDING_SAME_REPLICATE, expected_replicate },
    { PADDING_VALID, expected_valid },
  };
  for (const auto &padding_case : kPaddingCases) {
    for (int layer = 0; layer < cnn_config.num_layers; ++layer) {
      cnn_config.layer_config[layer].pad = padding_case.pad;
    }
    RunCNNTest(image_width, image_height, input, padding_case.expected,
               &cnn_config, image_width, &thread_data, MSE_INT_TOL);
  }
}
// Single 4x5 convolution followed by a RELU activation on an 8x8 image,
// checked once per padding mode.
TEST_F(CNNTest, TestRELUSingleLayer) {
  const int image_width = 8;
  const int image_height = 8;
  const int filter_height = 5;
  const int filter_width = 4;
  float input[] = {
      0, -2, -3, 1, -1, 2, -2, 1, -3, -1, 0, 1, -2, -3, -2, -2,
      1, -3, 2, -3, -1, -1, 2, 0, -2, -3, 0, -2, -3, 1, -1, -1,
      2, -2, 0, -2, -3, -3, 1, 1, -1, 1, 0, 1, -3, 0, 2, 2,
      0, -3, 1, -3, 2, -2, 1, -1, -1, -2, -3, -2, -1, -3, -2, -1,
  };
  float weights[] = {
      -2, -3, 1, 2, 2, -2, -3, 0, -3, 2,
      2, -3, -3, -2, 0, 1, 2, 0, -1, -1,
  };
  float bias[] = { -3 };
  float expected_same[] = {
      9, 0, 1, 1, 0, 3, 0, 19, 0, 12, 10, 0, 0, 0, 5, 0,
      0, 18, 21, 7, 19, 4, 3, 0, 0, 9, 16, 0, 11, 16, 0, 11,
      12, 2, 0, 11, 0, 16, 6, 0, 8, 22, 13, 10, 12, 0, 0, 0,
      0, 1, 2, 12, 29, 6, 10, 0, 13, 0, 0, 5, 8, 10, 0, 0,
  };
  float expected_replicate[] = {
      18, 17, 12, 2, 0, 0, 5, 11, 0, 17, 22, 6, 0, 0, 17, 0,
      0, 18, 21, 7, 19, 4, 3, 5, 3, 9, 16, 0, 11, 16, 0, 3,
      3, 2, 0, 11, 0, 16, 6, 0, 17, 22, 13, 10, 12, 0, 0, 0,
      0, 4, 1, 10, 30, 7, 10, 0, 23, 8, 0, 13, 15, 19, 8, 10,
  };
  float expected_valid[] = {
      18, 21, 7, 19, 4, 9, 16, 0, 11, 16,
      2, 0, 11, 0, 16, 22, 13, 10, 12, 0,
  };
  CNN_CONFIG cnn_config = { 1,
                            0,
                            0,
                            0,
                            0,
                            { {
                                1, filter_width, filter_height, 1, 1, 1, 0,
                                weights, bias, PADDING_SAME_ZERO, RELU, 0, 0,
                                BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0,
                            } } };
  CNN_THREAD_DATA thread_data = { 1, nullptr };

  // Re-run the single layer with each padding mode in turn.
  using PaddingType = decltype(cnn_config.layer_config[0].pad);
  const struct {
    PaddingType pad;
    const float *expected;
  } kPaddingCases[] = {
    { PADDING_SAME_ZERO, expected_same },
    { PADDING_SAME_REPLICATE, expected_replicate },
    { PADDING_VALID, expected_valid },
  };
  for (const auto &padding_case : kPaddingCases) {
    cnn_config.layer_config[0].pad = padding_case.pad;
    RunCNNTest(image_width, image_height, input, padding_case.expected,
               &cnn_config, image_width, &thread_data, MSE_INT_TOL);
  }
}
// Single 4x11 convolution over a non-square 17x24 image. The same layer is
// re-run after overriding its horizontal/vertical skip settings.
TEST_F(CNNTest, TestVaryingStridesVaryingDimImages) {
  const int image_height = 24;
  const int image_width = 17;
  float weights[] = {
      1, -5, -3, -4, -1, 1, 2, -3, 2, 2, -1, 1, -5, 1, 1,
      -3, -5, 3, 1, 4, -2, -5, -2, -3, -5, 0, -1, -5, 2, -2,
      -2, 1, -2, -4, 1, 3, -2, 2, 0, -3, 2, -3, -2, -3,
  };
  float bias[] = { 2 };
  float input[] = {
      -1 , -3 , 4 , 4 , -5 , 4 , 3 , -5 , -1 , -3 , 4 , -4 , 2 , -3 , 3 , -5 , 2 , -1 , -5 ,
      1 , -1 , 3 , 1 , -3 , -3 , 4 , 0 , 2 , -3 , -5 , -5 , -4 , 0 , -5 , -2 , -3 , -1 , -2 ,
      2 , -5 , 4 , 4 , 0 , -4 , -3 , 1 , -3 , -5 , -4 , -4 , 1 , -2 , -3 , 3 , -3 , -3 , -1 ,
      -5 , -5 , -2 , 3 , 1 , -1 , -5 , -5 , 1 , -4 , -2 , -1 , -2 , -4 , -4 , 2 , -2 , 2 , 1 ,
      -2 , -4 , -1 , 1 , -2 , -5 , 3 , -2 , -1 , -1 , -5 , -3 , 1 , -2 , -2 , -3 , -1 , -2 , -4 ,
      -2 , 1 , -4 , -1 , 4 , 3 , -4 , 0 , 4 , 2 , 2 , 4 , -3 , -5 , 2 , 2 , 1 , -1 , -4 ,
      -2 , 1 , 3 , 2 , 0 , 4 , -1 , -3 , 2 , 1 , -4 , 2 , 2 , -4 , -2 , 0 , -2 , -1 , 4 ,
      4 , 2 , 3 , -4 , 2 , -4 , -5 , 4 , -1 , -3 , -1 , 0 , -4 , 1 , 3 , -1 , -3 , -5 , 3 ,
      -2 , -4 , 1 , 2 , -2 , -3 , -3 , -5 , 1 , -3 , -1 , 0 , -1 , 3 , -4 , -1 , -5 , -5 , 1 ,
      0 , 0 , -2 , -2 , 2 , -2 , 0 , 0 , 2 , 0 , -3 , 0 , -1 , -4 , -4 , -1 , 3 , -4 , -4 ,
      -1 , 0 , -5 , -3 , -2 , 4 , -3 , -4 , -4 , 0 , -5 , 1 , -2 , -3 , -3 , -4 , 4 , 3 , 4 ,
      3 , 3 , -1 , 3 , 1 , -3 , -2 , 3 , 3 , 0 , 2 , -4 , -3 , 2 , 2 , 0 , -2 , 4 , -2 ,
      2 , -2 , -1 , -4 , -2 , 2 , -4 , 3 , -1 , 4 , 1 , 1 , 4 , -1 , -4 , -4 , 1 , 1 , -2 ,
      4 , -1 , 3 , 2 , -3 , 4 , 3 , 1 , 4 , 0 , -4 , 2 , 0 , 2 , 4 , -2 , -2 , 4 , 2 ,
      -1 , -2 , 1 , -3 , 2 , 3 , -5 , -3 , 4 , 4 , 2 , -5 , -4 , -5 , -2 , -4 , 2 , 0 , 2 ,
      -5 , 4 , -4 , -2 , -5 , 2 , 1 , 0 , 4 , 1 , -2 , -3 , -4 , -3 , -4 , 3 , 3 , 2 , 0 ,
      -3 , 1 , -5 , 4 , 0 , 4 , -1 , 3 , -5 , -5 , -2 , -1 , -1 , 4 , 3 , 3 , 4 , 3 , -4 ,
      4 , -3 , -3 , -1 , -4 , -1 , -4 , -1 , -2 , 4 , -2 , -4 , 4 , 4 , -3 , -4 , -1 , 1 , 2 ,
      -1 , -2 , -2 , 3 , 2 , 2 , -3 , 0 , -1 , 0 , 3 , 2 , -5 , 0 , -4 , 0 , 0 , 2 , -4 ,
      -1 , -1 , 0 , -2 , 0 , 1 , 0 , 0 , 4 , -5 , -1 , -5 , 2 , -1 , 0 , 2 , -1 , 1 , 3 ,
      -3 , -5 , -2 , -3 , 4 , -2 , -2 , -1 , -3 , -4 , -1 , -2 , -4 , 1 , 4 , -3 , -2 , -1 , 3 ,
      -3 , -2 , 3 , 2 , 1 , -4 , -3 , -5 , 1 ,
  };
  // Expected outputs, one per skip configuration exercised below.
  float expected_1[] = {
      41, -26, 5, 76, 13, 83, -21, 53, -54, -14, 21, 121,
  };
  float expected_2[] = {
      21, -50, 41, 20, 72, 127, -21, 103, 62, -37, 83, -3,
  };
  float expected_3[] = {
      -26, -21, -35, 69, 49, 4, -51, -43, -56,
      -41, 15, -44, 40, -62, 63, 38, 27, 47,
  };
  float expected_4[] = {
      21, 49, 28, 87, 50, 40, 102, 81, 58, 85, 51, 66, 36, 19, -37, -45,
  };
  CNN_CONFIG cnn_config = { 1,
                            0,
                            0,
                            0,
                            0,
                            {
                                {
                                    1, 4, 11, 1, 7, 6, 0, weights, bias,
                                    PADDING_SAME_ZERO, NONE, 0, 0,
                                    BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0,
                                },
                            } };
  CNN_THREAD_DATA thread_data = { 1, nullptr };

  // First pass uses the skip values given in the initializer above.
  RunCNNTest(image_width, image_height, input, expected_1, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);

  // Remaining passes override skip_width / skip_height before each run.
  const struct {
    int skip_width;
    int skip_height;
    const float *expected;
  } kSkipCases[] = {
    { 6, 7, expected_2 },
    { 3, 10, expected_3 },
    { 10, 3, expected_4 },
  };
  for (const auto &skip_case : kSkipCases) {
    cnn_config.layer_config[0].skip_width = skip_case.skip_width;
    cnn_config.layer_config[0].skip_height = skip_case.skip_height;
    RunCNNTest(image_width, image_height, input, skip_case.expected,
               &cnn_config, image_width, &thread_data, MSE_INT_TOL);
  }
}
// Single 3x3 convolution on an 8x8 image with zero padding. The layer's skip
// values are both set to |pool_stride|, and the field after them is set to 1;
// per the test's name, that field enables max-pooling.
TEST_F(CNNTest, TestMaxPool) {
  const int image_width = 8;
  const int image_height = 8;
  const int pool_stride = 3;
  float weights[] = { 3, 1, 3, 4, -1, 5, -2, 1, -4 };
  float bias[] = { -3 };
  float input[] = {
      1, -4, -4, 8, 0, 7, -5, -2, 8, 2, 2, 8, 5, -1, -1, 9,
      -3, 0, -2, 0, 6, 3, -4, 8, 7, 8, 7, -1, 4, -1, 0, 2,
      -5, -2, 8, 5, 5, 4, 2, 7, 4, 6, 2, 8, 8, -4, -3, -4,
      -3, -1, 2, 3, 3, 6, -5, 8, 9, 5, 0, -2, -1, 6, 5, 7,
  };
  float expected[] = { 49, 58, 70, 68, 68, 70, 48, 57, 88 };
  CNN_CONFIG cnn_config = { 1,
                            0,
                            0,
                            0,
                            0,
                            { {
                                1, 3, 3, 1, pool_stride, pool_stride, 1,
                                weights, bias, PADDING_SAME_ZERO, NONE, 0, 0,
                                BRANCH_NO_COPY, BRANCH_NOC, {}, {}, 0,
                            } } };
  CNN_THREAD_DATA thread_data = { 1, nullptr };
  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_INT_TOL);
}
TEST_F(CNNTest, TestDeconvolveNonActivationSingleLayerSingleKernel) {
int image_width = 4 ;
int image_height = 7 ;
float input[] = {
9 , 6 , 181 , 9 , 218 , 30 , 80 , 108 , 68 , 216 , 70 , 128 , 179 , 228 ,
33 , 212 , 34 , 14 , 48 , 27 , 230 , 23 , 202 , 113 , 80 , 56 , 122 , 112 ,
};
float expected_1_same[] = {
15 , -30 , 36 , -525 , 377 , -193 , 558 , 531 , 6 , -24 , -15 , 124 ,
166 , -561 , -356 , -754 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 ,
433 , -311 , 711 , 381 , 247 , -317 , 453 , 129 , 215 , -627 , -409 , -885 ,
17 , -255 , -55 , -647 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 ,
133 , -719 , 633 , -225 , 785 , 191 , 463 , 79 , 65 , 9 , 77 , -853 ,
-365 , -949 , -15 , -667 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 ,
355 , -866 , 990 , 207 , 747 , 12 , 520 , -116 , 176 , -312 , -133 , -1370 ,
-426 , -802 , 143 , -771 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 ,
65 , -79 , 127 , -59 , 135 , -90 , 195 , 114 , 31 , -91 , -57 , -133 ,
17 , -176 , -72 , -276 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 ,
457 , -302 , 733 , 58 , 470 , -475 , 829 , 490 , 227 , -670 , -440 , -790 ,
153 , -588 , -294 , -1150 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 ,
157 , -251 , 349 , -185 , 409 , -293 , 587 , 251 , 77 , -187 , -107 , -369 ,
7 , -481 , -135 , -827 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 ,
};
float expected_1_valid[] = {
-30 , 15 , -30 , 36 , -525 , 377 , -193 , 558 , 531 , 24 , 24 , 6 ,
6 , -24 , -15 , 124 , 166 , -561 , -356 , -754 , -21 , -39 , -3 , -3 ,
-3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -657 , 433 , -311 ,
711 , 381 , 247 , -317 , 453 , 129 , 321 , 321 , 215 , 215 , -627 , -409 ,
-885 , 17 , -255 , -55 , -647 , -219 , -435 , -3 , -3 , -3 , -3 , -3 ,
-3 , -3 , -3 , -3 , -3 , -3 , -207 , 133 , -719 , 633 , -225 , 785 ,
191 , 463 , 79 , 381 , 381 , 65 , 65 , 9 , 77 , -853 , -365 , -949 ,
-15 , -667 , -259 , -515 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 ,
-3 , -3 , -3 , -540 , 355 , -866 , 990 , 207 , 747 , 12 , 520 , -116 ,
633 , 633 , 176 , 176 , -312 , -133 , -1370 , -426 , -802 , 143 , -771 , -427 ,
-851 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 ,
-105 , 65 , -79 , 127 , -59 , 135 , -90 , 195 , 114 , 78 , 78 , 31 ,
31 , -91 , -57 , -133 , 17 , -176 , -72 , -276 , -57 , -111 , -3 , -3 ,
-3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -3 , -693 , 457 , -302 ,
733 , 58 , 470 , -475 , 829 , 490 , 336 , 336 , 227 , 227 , -670 , -440 ,
-790 , 153 , -588 , -294 , -1150 , -229 , -455 , -3 , -3 , -3 , -3 , -3 ,
-3 , -3 , -3 , -3 , -3 , -3 , -243 , 157 , -251 , 349 , -185 , 409 ,
-293 , 587 , 251 , 333 , 333 , 77 , 77 , -187 , -107 , -369 , 7 , -481 ,
-135 , -827 , -227 , -451 ,
};
float weights_1[] = { -3 , 2 , -1 , 3 , 3 , 1 , 1 , -3 , -2 , -4 };
float bias_1[] = { -3 };
CNN_CONFIG cnn_config = { 1 ,
0 ,
0 ,
0 ,
0 ,
{ {
1 ,
5 ,
2 ,
1 ,
2 ,
3 ,
0 ,
weights_1,
bias_1,
PADDING_SAME_ZERO,
NONE,
1 ,
0 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
0 ,
} } };
CNN_THREAD_DATA thread_data = { 1 , nullptr };
RunCNNTest(image_width, image_height, input, expected_1_same, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
// Change padding to valid
cnn_config.layer_config[0 ].pad = PADDING_VALID;
RunCNNTest(image_width, image_height, input, expected_1_valid, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
float expected_12_same[] = {
15 , -12 , 6 , 36 , -9 , -528 , 377 , -184 , 513 , 558 , -12 , 24 ,
6 , -30 , -15 , -33 , -21 , 166 , 154 , -546 , -356 , -718 , -30 , -21 ,
433 , -221 , 561 , 711 , -33 , -153 , 247 , -83 , -87 , 453 , -111 , 321 ,
215 , -657 , -409 , -845 , -93 , 17 , -43 , -243 , -55 , -215 , -327 , -219 ,
133 , -71 , -447 , 633 , -219 , 435 , 785 , -73 , -177 , 463 , -131 , 381 ,
65 , -207 , 77 , -59 , -651 , -365 , -797 , -213 , -15 , -155 , -387 , -259 ,
355 , -182 , -150 , 990 , -231 , 582 , 747 , -36 , -540 , 520 , -215 , 633 ,
176 , -540 , -133 , -491 , -687 , -426 , -882 , -102 , 143 , 77 , -639 , -427 ,
65 , -37 , 57 , 127 , -17 , -105 , 135 , -51 , 60 , 195 , -30 , 78 ,
31 , -105 , -57 , -125 , -45 , 17 , -11 , -147 , -72 , -168 , -84 , -57 ,
457 , -233 , 618 , 733 , -26 , -540 , 470 , -205 , 264 , 829 , -116 , 336 ,
227 , -693 , -440 , -900 , -72 , 153 , 107 , -609 , -294 , -698 , -342 , -229 ,
157 , -83 , 69 , 349 , -59 , -201 , 409 , -125 , 27 , 587 , -115 , 333 ,
77 , -243 , -107 , -267 , -171 , 7 , -105 , -369 , -135 , -379 , -339 , -227 ,
};
float expected_12_valid[] = {
-30 , 15 , -12 , 6 , 36 , -9 , -528 , 377 , -184 , 513 , 558 , -12 ,
24 , 24 , 6 , 6 , -30 , -15 , -33 , -21 , 166 , 154 , -546 , -356 ,
-718 , -30 , -21 , -39 , -657 , 433 , -221 , 561 , 711 , -33 , -153 , 247 ,
-83 , -87 , 453 , -111 , 321 , 321 , 215 , 215 , -657 , -409 , -845 , -93 ,
17 , -43 , -243 , -55 , -215 , -327 , -219 , -435 , -207 , 133 , -71 , -447 ,
633 , -219 , 435 , 785 , -73 , -177 , 463 , -131 , 381 , 381 , 65 , 65 ,
-207 , 77 , -59 , -651 , -365 , -797 , -213 , -15 , -155 , -387 , -259 , -515 ,
-540 , 355 , -182 , -150 , 990 , -231 , 582 , 747 , -36 , -540 , 520 , -215 ,
633 , 633 , 176 , 176 , -540 , -133 , -491 , -687 , -426 , -882 , -102 , 143 ,
77 , -639 , -427 , -851 , -105 , 65 , -37 , 57 , 127 , -17 , -105 , 135 ,
-51 , 60 , 195 , -30 , 78 , 78 , 31 , 31 , -105 , -57 , -125 , -45 ,
17 , -11 , -147 , -72 , -168 , -84 , -57 , -111 , -693 , 457 , -233 , 618 ,
733 , -26 , -540 , 470 , -205 , 264 , 829 , -116 , 336 , 336 , 227 , 227 ,
-693 , -440 , -900 , -72 , 153 , 107 , -609 , -294 , -698 , -342 , -229 , -455 ,
-243 , 157 , -83 , 69 , 349 , -59 , -201 , 409 , -125 , 27 , 587 , -115 ,
333 , 333 , 77 , 77 , -243 , -107 , -267 , -171 , 7 , -105 , -369 , -135 ,
-379 , -339 , -227 , -451 ,
};
// Change skip_width, skip_height to {2, 3}
cnn_config.layer_config[0 ].skip_width = 3 ;
cnn_config.layer_config[0 ].skip_height = 2 ;
// Set padding to same
cnn_config.layer_config[0 ].pad = PADDING_SAME_ZERO;
RunCNNTest(image_width, image_height, input, expected_12_same, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
// Change padding to valid
cnn_config.layer_config[0 ].pad = PADDING_VALID;
RunCNNTest(image_width, image_height, input, expected_12_valid, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
cnn_config.layer_config[0 ].filter_width = 4 ;
cnn_config.layer_config[0 ].filter_height = 3 ;
float weights_2[] = { -1 , -3 , -1 , -3 , 0 , 2 , -2 , 4 , 3 , 0 , 1 , 4 };
float bias_2[] = { -4 };
cnn_config.layer_config[0 ].weights = weights_2;
cnn_config.layer_config[0 ].bias = bias_2;
cnn_config.layer_config[0 ].skip_width = 5 ;
cnn_config.layer_config[0 ].skip_height = 2 ;
float expected_2_same[] = {
-13 , -31 , -13 , -31 , -4 , -10 , -22 , -10 , -22 , -4 , -185 , -547 ,
-185 , -547 , -4 , -13 , -31 , -13 , -31 , -4 , -4 , 14 , -22 , 32 ,
-4 , -4 , 8 , -16 , 20 , -4 , -4 , 358 , -366 , 720 , -4 , -4 ,
14 , -22 , 32 , -4 , -195 , -658 , -213 , -622 , -4 , -16 , -94 , -28 ,
-70 , -4 , 459 , -244 , 97 , 480 , -4 , -85 , -328 , -103 , -292 , -4 ,
-4 , 432 , -440 , 868 , -4 , -4 , 56 , -64 , 116 , -4 , -4 , 156 ,
-164 , 316 , -4 , -4 , 212 , -220 , 428 , -4 , 582 , -208 , 146 , 664 ,
-4 , -130 , -652 , -190 , -532 , -4 , 166 , -214 , 6 , 106 , -4 , 192 ,
-388 , -24 , 44 , -4 , -4 , 132 , -140 , 268 , -4 , -4 , 428 , -436 ,
860 , -4 , -4 , 136 , -144 , 276 , -4 , -4 , 252 , -260 , 508 , -4 ,
21 , -541 , -115 , -269 , -4 , 416 , -688 , -16 , 176 , -4 , 173 , -103 ,
33 , 177 , -4 , 168 , -640 , -88 , -128 , -4 , -4 , 354 , -362 , 712 ,
-4 , -4 , 452 , -460 , 908 , -4 , -4 , 62 , -70 , 128 , -4 , -4 ,
420 , -428 , 844 , -4 , 499 , -106 , 141 , 610 , -4 , 666 , -46 , 210 ,
866 , -4 , 47 , -148 , -19 , -16 , -4 , 605 , -85 , 181 , 763 , -4 ,
-4 , 64 , -72 , 132 , -4 , -4 , 24 , -32 , 52 , -4 , -4 , 92 ,
-100 , 188 , -4 , -4 , 50 , -58 , 104 , -4 , -132 , -694 , -200 , -558 ,
-4 , 15 , -73 , -13 , -17 , -4 , -62 , -610 , -158 , -418 , -4 , -36 ,
-343 , -90 , -235 , -4 , -4 , 456 , -464 , 916 , -4 , -4 , 42 , -50 ,
88 , -4 , -4 , 400 , -408 , 804 , -4 , -4 , 222 , -230 , 448 , -4 ,
606 , -244 , 146 , 676 , -4 , 9 , -172 , -37 , -80 , -4 , 480 , -370 ,
76 , 438 , -4 , 223 , -340 , -3 , 112 , -4 , -4 , 156 , -164 , 316 ,
-4 , -4 , 108 , -116 , 220 , -4 , -4 , 240 , -248 , 484 , -4 , -4 ,
220 , -228 , 444 , -4 ,
};
float expected_2_valid[] = {
-13 , -31 , -13 , -31 , -4 , -10 , -22 , -10 , -22 , -4 , -185 , -547 ,
-185 , -547 , -4 , -13 , -31 , -13 , -31 , -4 , 14 , -22 , 32 , -4 ,
-4 , 8 , -16 , 20 , -4 , -4 , 358 , -366 , 720 , -4 , -4 , 14 ,
-22 , 32 , -195 , -658 , -213 , -622 , -4 , -16 , -94 , -28 , -70 , -4 ,
459 , -244 , 97 , 480 , -4 , -85 , -328 , -103 , -292 , -4 , 432 , -440 ,
868 , -4 , -4 , 56 , -64 , 116 , -4 , -4 , 156 , -164 , 316 , -4 ,
-4 , 212 , -220 , 428 , 582 , -208 , 146 , 664 , -4 , -130 , -652 , -190 ,
-532 , -4 , 166 , -214 , 6 , 106 , -4 , 192 , -388 , -24 , 44 , -4 ,
132 , -140 , 268 , -4 , -4 , 428 , -436 , 860 , -4 , -4 , 136 , -144 ,
276 , -4 , -4 , 252 , -260 , 508 , 21 , -541 , -115 , -269 , -4 , 416 ,
-688 , -16 , 176 , -4 , 173 , -103 , 33 , 177 , -4 , 168 , -640 , -88 ,
-128 , -4 , 354 , -362 , 712 , -4 , -4 , 452 , -460 , 908 , -4 , -4 ,
62 , -70 , 128 , -4 , -4 , 420 , -428 , 844 , 499 , -106 , 141 , 610 ,
-4 , 666 , -46 , 210 , 866 , -4 , 47 , -148 , -19 , -16 , -4 , 605 ,
-85 , 181 , 763 , -4 , 64 , -72 , 132 , -4 , -4 , 24 , -32 , 52 ,
-4 , -4 , 92 , -100 , 188 , -4 , -4 , 50 , -58 , 104 , -132 , -694 ,
-200 , -558 , -4 , 15 , -73 , -13 , -17 , -4 , -62 , -610 , -158 , -418 ,
-4 , -36 , -343 , -90 , -235 , -4 , 456 , -464 , 916 , -4 , -4 , 42 ,
-50 , 88 , -4 , -4 , 400 , -408 , 804 , -4 , -4 , 222 , -230 , 448 ,
606 , -244 , 146 , 676 , -4 , 9 , -172 , -37 , -80 , -4 , 480 , -370 ,
76 , 438 , -4 , 223 , -340 , -3 , 112 , -4 , 156 , -164 , 316 , -4 ,
-4 , 108 , -116 , 220 , -4 , -4 , 240 , -248 , 484 , -4 , -4 , 220 ,
-228 , 444 , 236 , -4 , 76 , 316 , -4 , 164 , -4 , 52 , 220 , -4 ,
362 , -4 , 118 , 484 , -4 , 332 , -4 , 108 , 444 ,
};
// Set padding to same
cnn_config.layer_config[0 ].pad = PADDING_SAME_ZERO;
RunCNNTest(image_width, image_height, input, expected_2_same, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
cnn_config.layer_config[0 ].pad = PADDING_VALID;
RunCNNTest(image_width, image_height, input, expected_2_valid, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
cnn_config.layer_config[0 ].skip_width = 2 ;
cnn_config.layer_config[0 ].skip_height = 5 ;
float expected_21_same[] = {
-31 , -19 , -49 , -191 , -565 , -194 , -574 , -13 , 14 , -22 , 44 , -16 ,
382 , -366 , 738 , -22 , -4 , 23 , 32 , 545 , 20 , 204 , 720 , 5 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -658 , -252 , -748 , -114 , -334 , -192 , -568 , -112 ,
432 , -440 , 928 , -64 , 276 , -164 , 532 , -220 , -4 , 304 , 868 , 266 ,
116 , 400 , 316 , 104 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -208 , -288 , -856 , -290 ,
-862 , -202 , -598 , -132 , 132 , -140 , 700 , -436 , 1000 , -144 , 532 , -260 ,
-4 , 712 , 268 , 422 , 860 , 450 , 276 , 124 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-541 , -411 , -1225 , -265 , -787 , -249 , -739 , -216 , 354 , -362 , 1168 , -460 ,
974 , -70 , 552 , -428 , -4 , 859 , 712 , 323 , 908 , 665 , 128 , 208 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -106 , -52 , -148 , -66 , -190 , -79 , -229 , -31 ,
64 , -72 , 160 , -32 , 148 , -100 , 242 , -58 , -4 , 72 , 132 , 154 ,
52 , 125 , 188 , 23 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -694 , -257 , -763 , -229 ,
-679 , -319 , -949 , -117 , 456 , -464 , 962 , -50 , 492 , -408 , 1030 , -230 ,
-4 , 295 , 916 , 625 , 88 , 537 , 804 , 109 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-244 , -140 , -412 , -182 , -538 , -238 , -706 , -116 , 156 , -164 , 428 , -116 ,
464 , -248 , 708 , -228 , -4 , 244 , 316 , 418 , 220 , 454 , 484 , 108 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 ,
};
float expected_21_valid[] = {
-13 , -31 , -19 , -49 , -191 , -565 , -194 , -574 , -13 , -31 , -4 , 14 ,
-22 , 44 , -16 , 382 , -366 , 738 , -22 , 32 , 23 , -4 , 23 , 32 ,
545 , 20 , 204 , 720 , 5 , 32 , -4 , -4 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-4 , -4 , -222 , -658 , -252 , -748 , -114 , -334 , -192 , -568 , -112 , -328 ,
-4 , 432 , -440 , 928 , -64 , 276 , -164 , 532 , -220 , 428 , 650 , -4 ,
304 , 868 , 266 , 116 , 400 , 316 , 104 , 428 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -72 , -208 , -288 , -856 , -290 , -862 , -202 , -598 ,
-132 , -388 , -4 , 132 , -140 , 700 , -436 , 1000 , -144 , 532 , -260 , 508 ,
200 , -4 , 712 , 268 , 422 , 860 , 450 , 276 , 124 , 508 , -4 , -4 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -4 , -4 , -183 , -541 , -411 , -1225 , -265 , -787 ,
-249 , -739 , -216 , -640 , -4 , 354 , -362 , 1168 , -460 , 974 , -70 , 552 ,
-428 , 844 , 533 , -4 , 859 , 712 , 323 , 908 , 665 , 128 , 208 , 844 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -38 , -106 , -52 , -148 ,
-66 , -190 , -79 , -229 , -31 , -85 , -4 , 64 , -72 , 160 , -32 , 148 ,
-100 , 242 , -58 , 104 , 98 , -4 , 72 , 132 , 154 , 52 , 125 , 188 ,
23 , 104 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -234 , -694 ,
-257 , -763 , -229 , -679 , -319 , -949 , -117 , -343 , -4 , 456 , -464 , 962 ,
-50 , 492 , -408 , 1030 , -230 , 448 , 686 , -4 , 295 , 916 , 625 , 88 ,
537 , 804 , 109 , 448 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 , -4 ,
-84 , -244 , -140 , -412 , -182 , -538 , -238 , -706 , -116 , -340 , -4 , 156 ,
-164 , 428 , -116 , 464 , -248 , 708 , -228 , 444 , 236 , -4 , 244 , 316 ,
418 , 220 , 454 , 484 , 108 , 444 ,
};
cnn_config.layer_config[0 ].pad = PADDING_SAME_ZERO;
RunCNNTest(image_width, image_height, input, expected_21_same, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
cnn_config.layer_config[0 ].pad = PADDING_VALID;
RunCNNTest(image_width, image_height, input, expected_21_valid, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
}
TEST_F(CNNTest, TestLargeKernelsAndStrides) {
  // Single-layer convolutions whose kernels are larger than the image, run
  // once with a large stride and once (with the kernel transposed) at unit
  // stride. Both cases compare exactly against integer references.
  CNN_THREAD_DATA thread_data = { 1, nullptr };

  // ---- Case 1: 10 rows x 11 columns, 23x20 (w x h) kernel, 15x20 skip.
  // The reference output holds a single value.
  const int first_height = 10;
  const int first_width = 11;
  const float first_input[] = {
    4,  4,  2,  4,  2,  -5, -2, 3,  -1, 0,  0,  1,  2,  0,  -5, -2, -5, 1,  -3,
    -1, 4,  -3, 2,  -2, 1,  0,  1,  -3, -3, -4, -2, -2, 1,  -4, -1, 4,  1,  -4,
    -4, -4, 3,  2,  -5, 3,  -5, 1,  2,  -4, 1,  -1, 3,  4,  -2, 3,  -3, 3,  0,
    2,  -4, -5, -5, -2, -1, -2, 1,  1,  1,  -2, 4,  -5, 4,  -1, -1, 2,  3,  -4,
    2,  2,  3,  0,  0,  1,  0,  3,  2,  3,  1,  -2, 3,  -4, 3,  2,  4,  -2, 0,
    4,  -4, 1,  -3, -3, -3, -5, 1,  -3, -5, 0,  4,  -1, -3, 2,
  };
  float first_weights[] = {
    -3, 4,  -4, -3, -5, 1,  -2, 3,  1,  -4, -4, 0,  -1, 0,  3,  1,  -3, -2, 0,
    -1, 1,  3,  -4, -4, -3, -3, -2, 4,  3,  -5, 4,  2,  -3, 4,  -2, -1, 2,  -1,
    -5, 0,  -3, 0,  3,  -5, -5, 3,  -4, -1, -5, 3,  4,  0,  4,  -5, 2,  -1, 2,
    -1, -1, -1, -5, 0,  -4, 3,  -1, 1,  1,  -1, 3,  2,  -5, -4, 0,  -4, 4,  -5,
    -3, 4,  -5, 2,  -5, -4, -4, -1, 3,  3,  0,  2,  -4, 1,  -2, 1,  1,  0,  3,
    -2, 0,  1,  2,  4,  -3, -1, -5, -5, 2,  -4, 1,  1,  2,  -4, -2, -2, 2,  1,
    3,  4,  -5, 1,  -1, -3, -3, -1, -2, -5, 1,  -1, 0,  1,  4,  4,  0,  0,  4,
    -3, -1, -5, -3, 0,  1,  1,  1,  -5, 3,  4,  3,  -5, 3,  -2, -2, 0,  -4, 0,
    0,  -2, 1,  -4, -1, 0,  -5, -2, -2, -5, -3, -3, 1,  1,  -3, 2,  4,  2,  4,
    -4, -3, 3,  1,  1,  3,  -4, 4,  -2, -3, -3, -3, -3, -4, -2, 3,  -5, 2,  4,
    -1, -4, -4, 4,  -2, -1, 3,  -3, -4, -4, -2, 4,  1,  0,  2,  -1, 4,  -3, 1,
    4,  -3, 4,  4,  0,  -4, 3,  -2, -3, 2,  3,  -1, -3, 2,  1,  4,  -2, -3, 1,
    4,  -2, 2,  -2, -5, -2, 1,  4,  -1, -4, 4,  -5, 2,  -5, -4, -1, -2, 3,  1,
    2,  1,  -5, 1,  -5, -4, -1, -2, 2,  -2, -4, -3, -2, -2, 4,  -1, 2,  2,  -4,
    2,  -2, 4,  -4, -2, -2, 1,  -1, 1,  1,  1,  -4, -5, -2, 3,  -4, -1, 3,  -2,
    3,  2,  -5, -4, 0,  3,  -2, -4, -5, 3,  -2, -4, 2,  -2, 1,  -4, 0,  2,  -5,
    1,  -4, -1, -1, 4,  -5, -4, 0,  -5, -4, -3, -5, -4, 0,  2,  0,  -4, 2,  -2,
    1,  1,  -3, 2,  0,  -4, 0,  -4, 1,  0,  -5, -1, -1, -1, -5, 4,  2,  2,  -4,
    3,  -2, -2, 2,  -3, -2, -1, 2,  -4, -5, 2,  -2, -4, -5, -5, -1, 2,  -1, 0,
    -5, -2, -2, -5, 0,  1,  -1, -5, 0,  3,  2,  3,  0,  -3, -2, 0,  -5, -1, -2,
    2,  -4, -1, 2,  2,  -5, 2,  -4, 0,  3,  -3, 1,  0,  0,  1,  -5, -3, 1,  -1,
    0,  -4, -3, 2,  -4, -4, 4,  -1, 0,  1,  2,  -4, -5, 4,  -2, 1,  -4, -4, -3,
    -1, -1, 1,  -1, -4, -1, -4, -3, 2,  -1, -2, -4, 1,  1,  0,  -2, 0,  -4, 3,
    -3, 0,  -4, -1, -4, 2,  -1, -2, -5, -1, -2, -3, 3,  -1, 0,  -3, 0,  1,  -5,
    1,  -5, 0,  1,
  };
  float first_bias[] = { 3 };
  const float first_expected[] = { 118 };

  CNN_CONFIG cnn_config = {
    1,  // number of layers
    0,
    0,
    0,
    0,
    { {
        1,                  // input channels
        23,                 // filter_width
        20,                 // filter_height
        1,                  // output channels
        15,                 // skip_width
        20,                 // skip_height
        0,
        first_weights,      // weights
        first_bias,         // bias
        PADDING_SAME_ZERO,  // pad
        NONE,               // activation
        0,
        0,
        BRANCH_NO_COPY,
        BRANCH_NOC,
        {},
        {},
        0,
    } },
  };
  RunCNNTest(first_width, first_height, first_input, first_expected,
             &cnn_config, first_width, &thread_data, MSE_INT_TOL);

  // ---- Case 2: 11 rows x 10 columns, 20x23 (w x h) kernel, unit skip.
  // The reference output holds 110 values, one per input pixel.
  const int second_height = 11;
  const int second_width = 10;
  const float second_input[] = {
    -2, -2, 3,  -5, -1, -3, 1,  3,  2,  1,  1,  -5, 4,  1,  3,  -5, 3,  -3, -5,
    0,  -1, -3, -3, 1,  1,  -5, -1, -5, -5, -3, 0,  1,  -3, -1, -3, -3, 0,  3,
    4,  -4, -1, 3,  -3, -1, -3, 1,  -3, -2, -1, -4, -3, 2,  -4, 1,  -4, -1, -3,
    -5, -1, 2,  3,  0,  2,  2,  -5, 4,  1,  2,  -1, -4, 4,  -4, -4, 0,  -1, 1,
    -1, 1,  -3, -3, -2, 1,  2,  4,  4,  4,  -3, -3, 0,  1,  0,  1,  4,  1,  3,
    4,  -3, -2, -4, 4,  2,  0,  3,  4,  -1, 2,  -2, 1,  -3, -2,
  };
  float second_weights[] = {
    4,  -1, 1,  -1, 2,  4,  3,  3,  -4, 3,  -5, 1,  -1, -1, -2, -2, 0,  2,  -3,
    -2, 3,  -5, -1, 0,  -1, -2, -2, -1, 2,  4,  3,  1,  0,  0,  -3, 3,  -4, -1,
    -5, 4,  -2, -2, 1,  2,  -1, -3, 1,  2,  -5, 1,  -3, 3,  3,  0,  -4, -4, -5,
    -3, -4, -4, 4,  -2, 4,  4,  -2, 2,  -5, -1, -2, -5, -1, 4,  -3, 3,  -2, 0,
    -4, -3, 0,  -1, -2, 4,  2,  0,  -2, -5, -4, 1,  4,  -4, -2, 2,  -2, 1,  1,
    -4, 1,  -4, -4, -2, 4,  2,  -1, -5, -5, 1,  -3, -3, 3,  -3, -5, -3, 4,  -1,
    -1, -3, 0,  -4, 3,  -1, 0,  -2, 0,  -5, -2, -5, 2,  0,  -5, 2,  3,  -2, 2,
    4,  -1, 1,  -3, 2,  3,  2,  0,  -5, -4, -5, 2,  1,  1,  -1, -2, 3,  4,  2,
    -2, 4,  -2, 3,  1,  -4, -3, -1, 4,  4,  -3, -5, -2, 2,  0,  3,  -2, 3,  -1,
    -4, 0,  -2, 0,  3,  4,  -2, -3, -2, 0,  3,  4,  2,  -4, 0,  1,  2,  2,  -1,
    -1, 4,  1,  4,  -2, -1, -1, -5, 1,  -3, 3,  3,  -1, -4, 3,  -5, 0,  0,  -1,
    -4, -1, -2, 4,  -2, 3,  3,  -3, 1,  -1, 2,  -1, 4,  4,  -2, -2, 4,  -2, 0,
    3,  -3, -5, -1, -2, 4,  -4, 2,  -4, 0,  -2, 3,  -3, 2,  2,  -2, -5, -1, 4,
    3,  -2, -1, 3,  3,  -1, 3,  0,  -3, 0,  4,  2,  0,  -1, 4,  1,  1,  2,  1,
    3,  1,  1,  1,  -3, -5, -4, 4,  -4, 2,  0,  0,  -4, 1,  4,  -5, 4,  4,  0,
    1,  0,  -2, -4, -4, -3, 0,  1,  -5, 4,  0,  -3, -2, -4, 2,  4,  1,  -5, 1,
    -4, 1,  0,  -3, -3, 0,  2,  -5, 4,  3,  -2, -5, 3,  1,  -1, 0,  3,  -2, -2,
    3,  -2, -5, 4,  1,  -2, 2,  -1, 0,  4,  0,  -5, 3,  -2, 1,  2,  1,  -5, -3,
    -2, -5, 4,  -4, 0,  3,  2,  -1, -4, -1, 2,  1,  -2, 3,  -1, -4, 2,  0,  -3,
    1,  -1, 2,  -5, -4, -1, -5, 1,  4,  3,  4,  2,  -3, 1,  -5, -1, 3,  0,  -1,
    -4, 3,  4,  -5, 4,  4,  -3, 2,  -3, -1, -3, -5, -3, 2,  -3, -2, 1,  1,  0,
    -5, 3,  2,  1,  -5, 1,  1,  1,  3,  4,  -4, -1, -2, 0,  -5, -3, -5, -2, -4,
    3,  3,  3,  4,  0,  -4, -1, -5, 0,  -3, 1,  4,  4,  -4, 4,  -5, -5, -1, -2,
    -5, 3,  -4, 4,  3,  0,  -3, 2,  -2, 0,  0,  4,  4,  0,  -2, 1,  -1, -3, 2,
    -1, 1,  -3, -5,
  };
  float second_bias[] = {
    -5,
  };
  const float second_expected[] = {
    36,  -84,  95,   45,   18,   46,   77,  -54, -99,  -149, 66,  49,  161, 11,
    39,  61,   -66,  61,   4,    -3,   34,  -44, -23,  31,   64,  29,  47,  72,
    -27, -27,  121,  -3,   100,  1,    30,  -78, -12,  -89,  -59, 8,   -16, 112,
    91,  -102, -26,  -4,   30,   54,   4,   -84, -24,  -58,  27,  -53, -33, 5,
    53,  -26,  63,   50,   -103, -130, -23, 6,   -104, -207, 73,  23,  77,  132,
    38,  32,   -130, -44,  -60,  7,    27,  176, 45,   -32,  -2,  99,  -97, 63,
    69,  126,  47,   63,   136,  -57,  5,   16,  -40,  -157, 8,   38,  -44, -10,
    91,  7,    122,  140,  30,   -105, 4,   -1,  113,  64,   180, 141,
  };

  // Reuse the same layer, swapping in the transposed kernel and unit stride;
  // padding and activation stay as configured for case 1.
  auto &layer = cnn_config.layer_config[0];
  layer.weights = second_weights;
  layer.bias = second_bias;
  layer.filter_width = 20;
  layer.filter_height = 23;
  layer.skip_width = 1;
  layer.skip_height = 1;
  RunCNNTest(second_width, second_height, second_input, second_expected,
             &cnn_config, second_width, &thread_data, MSE_INT_TOL);
}
TEST_F(CNNTest, TestSoftsignSingleLayer) {
  // One 4x5 (w x h) convolution followed by a SOFTSIGN activation on an 8x8
  // image, checked under three padding modes against 3-decimal references
  // using the float MSE tolerance.
  const int kImageWidth = 8;
  const int kImageHeight = 8;
  const int kFilterHeight = 5;
  const int kFilterWidth = 4;

  const float input[] = {
    -0.5220f, 0.8410f,  -0.8990f, -0.0090f, 0.6710f,  -0.9470f, -0.8240f,
    -0.0870f, 0.5380f,  0.4750f,  0.570f,   -0.3760f, -0.6960f, -0.5940f,
    -0.3830f, 0.080f,   -0.0980f, -0.4940f, -0.4030f, 0.9460f,  -0.6020f,
    0.4220f,  0.6190f,  0.6640f,  -0.9210f, -0.1470f, -0.2480f, -0.1120f,
    -0.580f,  -0.0650f, 0.3330f,  0.9860f,  -0.7430f, 0.7610f,  0.4840f,
    0.1030f,  0.9570f,  0.6120f,  -0.5240f, -0.1220f, -0.5850f, -0.270f,
    0.7840f,  -0.9790f, 0.7290f,  -0.30f,   -0.6460f, 0.0780f,  0.4750f,
    -0.0510f, 0.4550f,  0.3850f,  -0.7230f, 0.4460f,  -0.6260f, -0.810f,
    0.8720f,  -0.2120f, -0.580f,  -0.9510f, -0.8430f, -0.1340f, -0.0850f,
    0.9190f,
  };

  // Reference output with PADDING_SAME_ZERO (8x8).
  const float expected_same[] = {
    0.430f,   0.660f,  0.5510f,  -0.610f,  0.450f,  -0.1610f, 0.0520f,  0.3240f,
    0.6820f,  0.3820f, 0.6360f,  0.7480f,  0.3080f, 0.090f,   0.3910f,  0.1730f,
    0.340f,   0.6660f, -0.4990f, 0.4280f,  0.1540f, 0.120f,   0.4670f,  0.6150f,
    -0.3880f, 0.7590f, 0.4190f,  0.7350f,  0.5310f, -0.5160f, -0.1760f, 0.6790f,
    -0.6780f, 0.5470f, 0.5750f,  -0.6420f, 0.7210f, -0.4620f, 0.5430f,  0.770f,
    -0.1990f, 0.3950f, 0.7860f,  -0.4380f, 0.7540f, 0.2640f,  -0.6430f, 0.4510f,
    -0.1260f, 0.1590f, -0.2110f, -0.0560f, 0.6570f, 0.680f,   0.5870f,  0.4720f,
    0.4040f,  0.3630f, 0.670f,   0.2360f,  0.410f,  0.6980f,  -0.5350f, 0.3940f,
  };

  // Reference output with PADDING_SAME_REPLICATE (8x8).
  const float expected_replicate[] = {
    0.540f,   0.7230f,  -0.3530f, -0.2130f, 0.7440f,  -0.4470f, -0.6260f,
    -0.2050f, 0.7230f,  0.4630f,  0.5920f,  0.7440f,  0.6080f,  0.3130f,
    -0.5670f, -0.4720f, 0.5480f,  0.6660f,  -0.4990f, 0.4280f,  0.1540f,
    0.120f,   0.3390f,  0.6090f,  0.4160f,  0.7590f,  0.4190f,  0.7350f,
    0.5310f,  -0.5160f, -0.490f,  0.4450f,  -0.610f,  0.5470f,  0.5750f,
    -0.6420f, 0.7210f,  -0.4620f, 0.3150f,  0.7370f,  -0.5820f, 0.3950f,
    0.7860f,  -0.4380f, 0.7540f,  0.2640f,  -0.7430f, -0.5340f, -0.6270f,
    0.4430f,  0.4730f,  0.4570f,  0.7450f,  0.630f,   0.2620f,  0.3140f,
    -0.1840f, 0.1810f,  0.7210f,  0.2760f,  0.6430f,  0.6720f,  -0.4390f,
    0.2040f,
  };

  // Reference output with PADDING_VALID (20 values).
  const float expected_valid[] = {
    0.6660f,  -0.4990f, 0.4280f, 0.1540f,  0.120f,  0.7590f,  0.4190f,
    0.7350f,  0.5310f,  -0.5160f, 0.5470f, 0.5750f, -0.6420f, 0.7210f,
    -0.4620f, 0.3950f,  0.7860f, -0.4380f, 0.7540f, 0.2640f,
  };

  float weights[] = {
    0.6210f,  0.3710f,  -0.2770f, -0.7230f, -0.2450f, 0.6770f,  0.3080f,
    -0.9880f, -0.080f,  0.7190f,  -0.6760f, -0.0170f, -0.8970f, 0.8260f,
    0.7390f,  -0.4550f, -0.4260f, -0.6330f, 0.0880f,  -0.9390f,
  };
  float bias[] = {
    0.750f,
  };

  CNN_CONFIG cnn_config = {
    1,  // number of layers
    0,
    0,
    0,
    0,
    { {
        1,                  // input channels
        kFilterWidth,       // filter_width
        kFilterHeight,      // filter_height
        1,                  // output channels
        1,                  // skip_width
        1,                  // skip_height
        0,
        weights,            // weights
        bias,               // bias
        PADDING_SAME_ZERO,  // pad (changed between runs below)
        SOFTSIGN,           // activation
        0,
        0,
        BRANCH_NO_COPY,
        BRANCH_NOC,
        {},
        {},
        0,
    } },
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };
  // Runs the network with whatever padding mode is currently configured and
  // compares against `reference`.
  const auto expect_output = [&](const float *reference) {
    RunCNNTest(kImageWidth, kImageHeight, input, reference, &cnn_config,
               kImageWidth, &thread_data, MSE_FLOAT_TOL);
  };

  expect_output(expected_same);

  cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
  expect_output(expected_replicate);

  cnn_config.layer_config[0].pad = PADDING_VALID;
  expect_output(expected_valid);
}
// Exercises BRANCH_ADD in a six-layer network: 2x3 (w x h) filters on a 4x4
// single-channel image, `channels` (2) channels in the hidden layers.
// Routing as read from the layer initializers below:
//   layer 0: 1 -> channels on the main path (branch index 0).
//   layer 1: BRANCH_INPUT with first branch_config entry 0x02 -- presumably
//            copies this layer's input tensor into branch 1.
//   layers 2-3: branch index 1 (13th initializer field).
//   layer 4: BRANCH_ADD with third branch_config entry 0x02 -- presumably
//            adds branch 1's tensor into this layer's result.
//   layer 5: channels -> 1, the final layer.
// The output is compared exactly (MSE_INT_TOL).
// NOTE(review): "mask bit k == branch k" is inferred from the values used
// here; confirm against the branch config definition in av1/encoder/cnn.h.
TEST_F(CNNTest, TestBranchTensorAdd) {
int filter_width = 2 ;
int filter_height = 3 ;
int image_width = 4 ;
int image_height = 4 ;
float input[] = {
-3 , -2 , -2 , 0 , -1 , 3 , 2 , -2 , 1 , 3 , 4 , 0 , 2 , -5 , -4 , 0 ,
};
// Flat weights for all six layers (120 entries). 120 equals the sum over
// layers of in_ch * out_ch * 2 * 3 (12 + 4 * 24 + 12); presumably this is the
// layout AssignLayerWeightsBiases expects -- confirm in that helper.
float weights[] = {
-3 , -1 , 4 , -1 , -3 , 3 , 3 , 0 , 2 , 0 , 3 , 2 , 4 , 4 , 4 , -5 , 1 , -4 ,
2 , -4 , 1 , -3 , 0 , 4 , -5 , 4 , 0 , -4 , -3 , -1 , 0 , 0 , -2 , 0 , 0 , 2 ,
-5 , -1 , 1 , -3 , 3 , 4 , 3 , 0 , 1 , -1 , 1 , 1 , 2 , 4 , -2 , -5 , 2 , -2 ,
3 , -2 , 4 , -1 , 0 , 2 , 3 , 2 , -2 , -1 , -3 , 1 , 3 , 4 , -1 , -3 , 0 , -4 ,
4 , 2 , -3 , -3 , -1 , 0 , 1 , 0 , 3 , 3 , -3 , 0 , 3 , 2 , -5 , -3 , 4 , -5 ,
3 , -1 , -1 , -3 , 0 , 1 , -1 , -4 , 2 , 4 , -1 , 4 , -1 , 1 , 3 , 4 , 4 , 4 ,
0 , -1 , -3 , -3 , -3 , -3 , 2 , -3 , -2 , 2 , 3 , -3 ,
};
// 11 biases: presumably one per output channel of each layer (5 * 2 + 1).
float bias[] = {
3 , 4 , -1 , -1 , 2 , 1 , -2 , 1 , 4 , 1 , 3 ,
};
// Reference output: one channel over the 4x4 image (16 values).
float expected[] = {
-11502 , -4101 , -3424 , 668 , -17950 , -5470 , -5504 , 626 ,
4835 , 446 , 1779 , -3483 , 3679 , -4214 , 4578 , -105 ,
};
int channels = 2 ;
CNN_CONFIG cnn_config = { 6 ,
0 ,
0 ,
0 ,
0 ,
// layer 0: 1 -> channels; holds the weights/bias pointers.
{ {
1 ,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
weights,
bias,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1 ,
},
// layer 1: BRANCH_INPUT, copy mask 0x02 (presumably into branch 1).
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_INPUT,
BRANCH_NOC,
{
0 x02,
0 ,
0 x00,
},
{},
-1 ,
},
// layer 2: runs on branch index 1.
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
1 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1 ,
},
// layer 3: runs on branch index 1.
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
1 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1 ,
},
// layer 4: main path, BRANCH_ADD with combine mask 0x02 (presumably
// adding branch 1's tensor).
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_NO_COPY,
BRANCH_ADD,
{
0 x00,
0 ,
0 x02,
},
{},
-1 ,
},
// layer 5: channels -> 1. Its last field is 0 while every earlier layer
// uses -1; presumably this marks it as network output 0.
{
channels,
filter_width,
filter_height,
1 ,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
0 ,
} } };
// Weights and biases need to be specified separately because
// of the offset.
AssignLayerWeightsBiases(&cnn_config, weights, bias);
CNN_THREAD_DATA thread_data = { 1 , nullptr };
RunCNNTest(image_width, image_height, input, expected, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
}
// Exercises BRANCH_CAT with the same six-layer shape as TestBranchTensorAdd.
// Layer 1 (BRANCH_INPUT, mask 0x02) presumably copies its input into branch 1.
// Layers 2-3 run on branch index 1. Layer 4 (BRANCH_CAT, combine mask 0x02)
// presumably appends branch 1's channels to the main path. That is why the
// final layer takes `channels + channels` input channels and reduces them to 1.
// NOTE(review): the mask-bit-to-branch mapping is inferred from these values;
// confirm against the branch config definition in av1/encoder/cnn.h.
TEST_F(CNNTest, TestBranchTensorConcatenation) {
int filter_width = 2 ;
int filter_height = 3 ;
int image_width = 4 ;
int image_height = 4 ;
float input[] = {
-3 , -2 , -2 , 0 , -1 , 3 , 2 , -2 , 1 , 3 , 4 , 0 , 2 , -5 , -4 , 0 ,
};
// Flat weights for all six layers (132 entries = 12 + 4 * 24 + 24, i.e.
// in_ch * out_ch * 2 * 3 summed over layers). Layout assumed; it is split
// up by AssignLayerWeightsBiases below.
float weights[] = {
3 , 0 , 2 , 0 , 2 , 3 , 1 , -3 , 1 , -5 , -3 , 0 , -4 , 4 , 0 , -5 , 0 , -5 , -1 ,
-2 , -5 , 0 , -3 , 2 , -4 , 2 , 0 , 2 , -1 , 0 , -4 , 3 , 0 , 0 , -1 , -5 , 2 , -1 ,
4 , -4 , -2 , -3 , -3 , 3 , 4 , -2 , -1 , -4 , -1 , 4 , 4 , -1 , 4 , 3 , -4 , 2 , -2 ,
-4 , -3 , -2 , 3 , -3 , -5 , -1 , 3 , -2 , 4 , 1 , -4 , -3 , -5 , -5 , -3 , 4 , -2 , -2 ,
-1 , -5 , -5 , 0 , -1 , -2 , -3 , 3 , -4 , -5 , 2 , -3 , 1 , 0 , -5 , 2 , 2 , -2 , 0 ,
2 , 2 , -2 , 4 , 2 , 2 , 0 , 1 , -5 , -3 , 0 , 2 , -2 , 1 , 2 , -5 , 2 , 3 , 3 ,
-1 , 3 , 0 , -3 , 3 , -4 , -4 , 3 , 3 , -4 , -2 , 2 , -2 , 2 , -2 , -1 , 3 , 0 ,
};
// 11 biases: presumably one per output channel of each layer (5 * 2 + 1).
float bias[] = {
-3 , -5 , 4 , -4 , -3 , -2 , 0 , 3 , -4 , 4 , -3 ,
};
// Reference output: one channel over the 4x4 image (16 values).
float expected[] = {
-33533 , -32087 , -6741 , -2124 , 39979 , 41453 , 14034 , 689 ,
-22611 , -42203 , -14882 , -239 , 15781 , 15963 , 9524 , 837 ,
};
int channels = 2 ;
CNN_CONFIG cnn_config = { 6 ,
0 ,
0 ,
0 ,
0 ,
// layer 0: 1 -> channels; holds the weights/bias pointers.
{ {
1 ,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
weights,
bias,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1 ,
},
// layer 1: BRANCH_INPUT, copy mask 0x02 (presumably into branch 1).
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_INPUT,
BRANCH_NOC,
{
0 x02,
0 ,
0 x00,
},
{},
-1 ,
},
// layer 2: runs on branch index 1.
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
1 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1 ,
},
// layer 3: runs on branch index 1.
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
1 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1 ,
},
// layer 4: main path, BRANCH_CAT with combine mask 0x02 (presumably
// concatenating branch 1's channels onto this layer's output).
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_NO_COPY,
BRANCH_CAT,
{
0 x00,
0 ,
0 x02,
},
{},
-1 ,
},
// layer 5: (channels + channels) -> 1, the final layer (last field 0).
{
channels + channels,
filter_width,
filter_height,
1 ,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
0 ,
} } };
// Weights and biases need to be specified separately because
// of the offset.
AssignLayerWeightsBiases(&cnn_config, weights, bias);
CNN_THREAD_DATA thread_data = { 1 , nullptr };
RunCNNTest(image_width, image_height, input, expected, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
}
// TODO(logangw): Add test to test all combinations of branch_copy_type.
// Ten-layer network mixing BRANCH_INPUT, BRANCH_OUTPUT and nested BRANCH_ADD.
// Branch index is the 13th initializer field of each layer.
//   layer 0: 1 -> channels on branch 0.
//   layer 1: BRANCH_INPUT, copy mask 0x06 (presumably branches 1 and 2).
//   layer 2: branch 2, BRANCH_OUTPUT, copy mask 0x08 (presumably branch 3).
//   layer 3: branch 3.
//   layer 4: branch 2, BRANCH_ADD, combine mask 0x08 (presumably branch 3).
//   layer 5: branch 2.
//   layer 6: branch 1.
//   layer 7: branch 1, BRANCH_ADD, combine mask 0x0C (presumably 2 and 3).
//   layer 8: branch 0, BRANCH_ADD, combine mask 0x02 (presumably branch 1).
//   layer 9: channels -> 1, the final layer.
// NOTE(review): the mask-bit-to-branch mapping is inferred from these values;
// confirm against the branch config definition in av1/encoder/cnn.h.
TEST_F(CNNTest, TestBranchCombinations) {
int filter_width = 2 ;
int filter_height = 3 ;
int image_width = 4 ;
int image_height = 4 ;
float input[] = {
3 , 2 , -5 , -4 , 4 , -2 , -4 , -3 , 4 , 2 , -3 , 2 , -3 , 1 , -5 , -1 ,
};
// Flat weights for all ten layers (216 entries = 12 + 8 * 24 + 12, i.e.
// in_ch * out_ch * 2 * 3 summed over layers). Layout assumed; it is split
// up by AssignLayerWeightsBiases below.
float weights[] = {
2 , 3 , 0 , 4 , 4 , 3 , 1 , 0 , 1 , -5 , 4 , -3 , 3 , 0 , 4 , -1 , -1 , -5 ,
2 , 1 , -3 , -5 , 3 , -1 , -3 , -2 , 0 , -2 , 3 , 0 , -2 , -4 , -2 , -2 , 2 , -5 ,
4 , -5 , 0 , 1 , -5 , -4 , -3 , -4 , 2 , -2 , 1 , 0 , 3 , -2 , -4 , 3 , 4 , -4 ,
-1 , -1 , -3 , -2 , -2 , -1 , 2 , 0 , 2 , -1 , 2 , -4 , -4 , -1 , 2 , 0 , 3 , -2 ,
-2 , 3 , -3 , 4 , -2 , 4 , 3 , 4 , 1 , 0 , -2 , -3 , -5 , 1 , -3 , 2 , 0 , -2 ,
-2 , -1 , -1 , -5 , -2 , -3 , -1 , 3 , 3 , 4 , 4 , 0 , 2 , 1 , 3 , -3 , 2 , -5 ,
-5 , 1 , -5 , -1 , 3 , 3 , 2 , -4 , -1 , 3 , -4 , -2 , -5 , -2 , 1 , 3 , 2 , 2 ,
-5 , -2 , -3 , -1 , -2 , -4 , -1 , -2 , 2 , 1 , -4 , -4 , 2 , 0 , 2 , 0 , 2 , -3 ,
-2 , -4 , 4 , 0 , 1 , -3 , -5 , 4 , -1 , 2 , 3 , -5 , -1 , 0 , 4 , -1 , -1 , 3 ,
-1 , -3 , 3 , 1 , 4 , 3 , 4 , 3 , -4 , -5 , -1 , 3 , 3 , -4 , 3 , 1 , 3 , -5 ,
3 , 4 , -5 , 4 , 2 , -1 , -5 , 2 , 1 , 0 , 4 , 0 , -3 , 2 , 0 , 2 , -2 , 1 ,
-1 , -2 , -1 , -5 , 4 , 3 , 3 , -2 , 2 , 4 , -5 , -5 , -3 , -2 , 4 , 0 , -4 , 1 ,
};
// 19 biases: presumably one per output channel of each layer (9 * 2 + 1).
float bias[] = {
-1 , 4 , 0 , 2 , 2 , -2 , 0 , -4 , -5 , -1 , 1 , -2 , 3 , 0 , 4 , -2 , 1 , 0 , 0 ,
};
// Reference output: one channel over the 4x4 image (16 values).
float expected[] = {
149496 , 15553 , -24193 , -20956 , 134094 , 86432 , -68283 , -6366 ,
-53031 , 133739 , 67407 , -13539 , -53205 , -58635 , -20033 , 1979 ,
};
int channels = 2 ;
CNN_CONFIG cnn_config = { 10 ,
0 ,
0 ,
0 ,
0 ,
{
// layer 0: 1 -> channels on branch 0; holds the weights/bias pointers.
{
1 ,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
weights,
bias,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1 ,
},
// layer 1: branch 0, BRANCH_INPUT with copy mask 0x06.
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_INPUT,
BRANCH_NOC,
{
0 x06,
0 ,
0 x00,
},
{},
-1 ,
},
// layer 2: branch 2, BRANCH_OUTPUT with copy mask 0x08.
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
2 ,
BRANCH_OUTPUT,
BRANCH_NOC,
{
0 x08,
0 ,
0 x00,
},
{},
-1 ,
},
// layer 3: branch 3.
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
3 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1 ,
},
// layer 4: branch 2, BRANCH_ADD with combine mask 0x08.
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
2 ,
BRANCH_NO_COPY,
BRANCH_ADD,
{
0 x00,
0 ,
0 x08,
},
{},
-1 ,
},
// layer 5: branch 2.
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
2 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1 ,
},
// layer 6: branch 1.
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
1 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
-1 ,
},
// layer 7: branch 1, BRANCH_ADD with combine mask 0x0C.
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
1 ,
BRANCH_NO_COPY,
BRANCH_ADD,
{
0 x00,
0 ,
0 x0C,
},
{},
-1 ,
},
// layer 8: branch 0, BRANCH_ADD with combine mask 0x02.
{
channels,
filter_width,
filter_height,
channels,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_NO_COPY,
BRANCH_ADD,
{
0 x00,
0 ,
0 x02,
},
{},
-1 ,
},
// layer 9: channels -> 1, the final layer (last field 0).
{
channels,
filter_width,
filter_height,
1 ,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
0 ,
},
} };
// Weights and biases need to be specified separately because
// of the offset.
AssignLayerWeightsBiases(&cnn_config, weights, bias);
CNN_THREAD_DATA thread_data = { 1 , nullptr };
RunCNNTest(image_width, image_height, input, expected, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
}
// Exercises copying only part of a tensor into a branch and concatenating it
// back.
//   layer 0: 1 -> 4 channels. It is BRANCH_OUTPUT with branch_config
//            {0x02, 2, 0x00}, presumably copying 2 of its output channels
//            into branch 1.
//   layer 1: 4 -> 2 channels. It is BRANCH_CAT with combine mask 0x02,
//            presumably appending those 2 branch-1 channels so that
//            layer 2 sees 4 input channels.
//   layer 2: 4 -> 1, the final layer.
// NOTE(review): the meaning of the middle branch_config entry (2) is inferred
// from the channel counts; confirm in av1/encoder/cnn.h.
TEST_F(CNNTest, TestSplittingTensors) {
int filter_width = 2 ;
int filter_height = 3 ;
int image_width = 4 ;
int image_height = 4 ;
float input[] = {
-1 , -1 , 2 , 1 , 3 , 2 , 4 , -3 , -4 , -2 , 2 , -3 , 1 , -3 , 4 , -2 ,
};
// Flat weights for all three layers (96 entries = 24 + 48 + 24, i.e.
// in_ch * out_ch * 2 * 3 summed over layers). Layout assumed; it is split
// up by AssignLayerWeightsBiases below.
float weights[] = {
-4 , 1 , 0 , 2 , 3 , 4 , 4 , -4 , -5 , -3 , 2 , 2 , -4 , -3 , 3 , 2 ,
4 , -4 , -3 , -4 , -4 , 1 , -3 , -5 , -3 , 4 , 2 , -2 , 2 , -1 , -4 , -1 ,
-2 , -3 , 1 , 1 , 0 , -5 , -1 , 3 , 3 , -5 , -3 , 0 , -3 , 1 , -3 , -1 ,
1 , -3 , -2 , -2 , 4 , -2 , 0 , 1 , 2 , 2 , -4 , 2 , 4 , 0 , -5 , -2 ,
4 , 4 , -5 , 1 , 0 , 2 , -2 , -5 , -5 , -3 , -5 , -5 , 4 , -3 , 0 , 0 ,
-4 , -4 , 0 , -5 , -4 , 0 , 0 , -3 , -5 , -3 , -1 , 2 , -1 , 4 , -1 , 2 ,
};
// 7 biases: presumably one per output channel of each layer (4 + 2 + 1).
float bias[] = {
-4 , -2 , -3 , -3 , 3 , 1 , -2 ,
};
// Reference output: one channel over the 4x4 image (16 values).
float expected[] = {
530 , -762 , 1469 , 777 , 849 , -771 , -1698 , 600 ,
-658 , -1821 , 98 , -668 , -1798 , 30 , 887 , -971 ,
};
CNN_CONFIG cnn_config = { 3 ,
0 ,
0 ,
0 ,
0 ,
{
// layer 0: 1 -> 4 channels, BRANCH_OUTPUT, branch_config {0x02, 2, 0x00}.
{
1 ,
filter_width,
filter_height,
4 ,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_OUTPUT,
BRANCH_NOC,
{
0 x02,
2 ,
0 x00,
},
{},
-1 ,
},
// layer 1: 4 -> 2 channels, BRANCH_CAT with combine mask 0x02.
{
4 ,
filter_width,
filter_height,
2 ,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_NO_COPY,
BRANCH_CAT,
{
0 x00,
0 ,
0 x02,
},
{},
-1 ,
},
// layer 2: 4 -> 1, the final layer (last field 0).
{
4 ,
filter_width,
filter_height,
1 ,
1 ,
1 ,
0 ,
nullptr,
nullptr,
PADDING_SAME_ZERO,
NONE,
0 ,
0 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
{},
0 ,
},
} };
// Weights and biases need to be specified separately because
// of the offset.
AssignLayerWeightsBiases(&cnn_config, weights, bias);
CNN_THREAD_DATA thread_data = { 1 , nullptr };
RunCNNTest(image_width, image_height, input, expected, &cnn_config,
image_width, &thread_data, MSE_INT_TOL);
}
TEST_F(CNNTest, TestOutputChannelsCount) {
  // Every input, weight and bias is zero and no activation is applied, so all
  // computed values are zero. What this test pins down is the size of the
  // result: 28 expected entries for a 2x2 image (presumably 7 output channels
  // produced by the BRANCH_INPUT / BRANCH_CAT routing below).
  const int kFilterW = 1;
  const int kFilterH = 1;
  const int kImageW = 2;
  const int kImageH = 2;

  const float input[4] = {};
  float weights[8] = {};
  float bias[6] = {};
  const float expected[28] = {};

  CNN_CONFIG cnn_config = {
    3,  // number of layers
    0,
    0,
    0,
    0,
    {
        // Layer 0 (branch 0): 1 -> 2 channels. BRANCH_INPUT with copy mask
        // 0x06, presumably copying its input into branches 1 and 2.
        {
            1,
            kFilterW,
            kFilterH,
            2,
            1,
            1,
            0,
            weights,
            bias,
            PADDING_SAME_ZERO,
            NONE,
            0,
            0,
            BRANCH_INPUT,
            BRANCH_NOC,
            { 0x06, 0, 0x00 },
            {},
            -1,
        },
        // Layer 1 (branch 2): 1 -> 2 channels. BRANCH_CAT with combine mask
        // 0x03 (presumably branches 0 and 1).
        {
            1,
            kFilterW,
            kFilterH,
            2,
            1,
            1,
            0,
            weights,
            bias,
            PADDING_SAME_ZERO,
            NONE,
            0,
            2,
            BRANCH_NO_COPY,
            BRANCH_CAT,
            { 0x00, 0, 0x03 },
            {},
            -1,
        },
        // Layer 2 (branch 0): 2 -> 2 channels. BRANCH_CAT with combine mask
        // 0x04 (presumably branch 2). This is the final layer (last field 0).
        {
            2,
            kFilterW,
            kFilterH,
            2,
            1,
            1,
            0,
            weights,
            bias,
            PADDING_SAME_ZERO,
            NONE,
            0,
            0,
            BRANCH_NO_COPY,
            BRANCH_CAT,
            { 0x00, 0, 0x04 },
            {},
            0,
        },
    },
  };

  // Weights and biases need to be specified separately because
  // of the offset.
  AssignLayerWeightsBiases(&cnn_config, weights, bias);

  CNN_THREAD_DATA thread_data = { 1, nullptr };
  RunCNNTest(kImageW, kImageH, input, expected, &cnn_config, kImageW,
             &thread_data, MSE_FLOAT_TOL);
}
TEST_F(CNNTest, TestBatchNorm) {
int image_width = 28 ;
int image_height = 28 ;
int filter_height = 7 ;
int filter_width = 7 ;
float input[] = {
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0117647 f, 0 .0705882 f, 0 .0705882 f, 0 .0705882 f,
0 .494118 f, 0 .533333 f, 0 .686275 f, 0 .101961 f, 0 .65098 f, 1 .0 f,
0 .968627 f, 0 .498039 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .117647 f, 0 .141176 f, 0 .368627 f, 0 .603922 f,
0 .666667 f, 0 .992157 f, 0 .992157 f, 0 .992157 f, 0 .992157 f, 0 .992157 f,
0 .882353 f, 0 .67451 f, 0 .992157 f, 0 .94902 f, 0 .764706 f, 0 .25098 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .192157 f,
0 .933333 f, 0 .992157 f, 0 .992157 f, 0 .992157 f, 0 .992157 f, 0 .992157 f,
0 .992157 f, 0 .992157 f, 0 .992157 f, 0 .984314 f, 0 .364706 f, 0 .321569 f,
0 .321569 f, 0 .219608 f, 0 .152941 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0705882 f, 0 .858824 f, 0 .992157 f,
0 .992157 f, 0 .992157 f, 0 .992157 f, 0 .992157 f, 0 .776471 f, 0 .713725 f,
0 .968627 f, 0 .945098 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .313725 f, 0 .611765 f, 0 .419608 f, 0 .992157 f,
0 .992157 f, 0 .803922 f, 0 .0431373 f, 0 .0 f, 0 .168627 f, 0 .603922 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .054902 f, 0 .00392157 f, 0 .603922 f, 0 .992157 f, 0 .352941 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .545098 f, 0 .992157 f, 0 .745098 f, 0 .00784314 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0431373 f,
0 .745098 f, 0 .992157 f, 0 .27451 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .137255 f, 0 .945098 f,
0 .882353 f, 0 .627451 f, 0 .423529 f, 0 .00392157 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .317647 f, 0 .941176 f, 0 .992157 f,
0 .992157 f, 0 .466667 f, 0 .0980392 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .176471 f, 0 .729412 f, 0 .992157 f, 0 .992157 f,
0 .588235 f, 0 .105882 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0627451 f, 0 .364706 f, 0 .988235 f, 0 .992157 f, 0 .733333 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .976471 f, 0 .992157 f, 0 .976471 f, 0 .25098 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .180392 f, 0 .509804 f, 0 .717647 f, 0 .992157 f,
0 .992157 f, 0 .811765 f, 0 .00784314 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .152941 f, 0 .580392 f,
0 .898039 f, 0 .992157 f, 0 .992157 f, 0 .992157 f, 0 .980392 f, 0 .713725 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0941176 f, 0 .447059 f, 0 .866667 f, 0 .992157 f, 0 .992157 f, 0 .992157 f,
0 .992157 f, 0 .788235 f, 0 .305882 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0901961 f, 0 .258824 f, 0 .835294 f, 0 .992157 f,
0 .992157 f, 0 .992157 f, 0 .992157 f, 0 .776471 f, 0 .317647 f, 0 .00784314 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0705882 f, 0 .670588 f,
0 .858824 f, 0 .992157 f, 0 .992157 f, 0 .992157 f, 0 .992157 f, 0 .764706 f,
0 .313725 f, 0 .0352941 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .215686 f, 0 .67451 f, 0 .886275 f, 0 .992157 f, 0 .992157 f, 0 .992157 f,
0 .992157 f, 0 .956863 f, 0 .521569 f, 0 .0431373 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .533333 f, 0 .992157 f,
0 .992157 f, 0 .992157 f, 0 .831373 f, 0 .529412 f, 0 .517647 f, 0 .0627451 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f,
0 .0 f, 0 .0 f, 0 .0 f, 0 .0 f
};
float expected[] = {
-0 .836424 f, -0 .857365 f, -1 .62739 f, -1 .62739 f, -0 .836424 f, 5 .40742 f,
0 .920853 f, -0 .692567 f, -0 .836424 f, -0 .534405 f, -1 .62739 f, -0 .836424 f,
1 .32602 f, 1 .36312 f, 0 .112766 f, -0 .836424 f, -0 .192962 f, 1 .56975 f,
2 .45777 f, 0 .944414 f, -0 .192962 f, -1 .5519 f, -1 .5519 f, -0 .554006 f,
-0 .192962 f, 1 .4231 f, -1 .5519 f, -0 .192962 f, 1 .3661 f, -1 .5519 f,
-1 .5519 f, -0 .192962 f, -0 .843708 f, -0 .359025 f, -0 .843708 f, -0 .843708 f,
-0 .843708 f, 4 .53065 f, 0 .0429584 f, -0 .796804 f, -0 .843708 f, 0 .3473 f,
-0 .843708 f, -0 .843708 f, -0 .114439 f, 3 .14817 f, 0 .0811934 f, -0 .843708 f
};
float kernel[] = {
0 .119643 f, -0 .237864 f, 0 .0462892 f, 0 .0502297 f, -0 .0134528 f,
0 .146347 f, 0 .153133 f, 0 .0513307 f, 0 .0752369 f, 0 .0135557 f,
-0 .111434 f, 0 .0941854 f, 0 .0788362 f, 0 .0299412 f, 0 .111762 f,
0 .144066 f, 0 .00431504 f, -0 .0177954 f, 0 .0738092 f, -0 .0344215 f,
0 .0832582 f, 0 .053989 f, -0 .112691 f, 0 .0962145 f, 0 .0186525 f,
-0 .00660205 f, -0 .111962 f, -0 .126801 f, -0 .231625 f, 0 .17309 f,
0 .0748875 f, -0 .179569 f, -0 .00513812 f, -0 .156579 f, -0 .147322 f,
0 .184168 f, 0 .189308 f, -0 .200359 f, -0 .0156733 f, 0 .140649 f,
0 .0858496 f, -0 .0263217 f, -0 .0740749 f, -0 .112563 f, 0 .107528 f,
0 .0609729 f, -0 .221625 f, 0 .0769944 f, -0 .00900815 f, -0 .00136441 f,
-0 .0236521 f, -0 .0418025 f, -0 .00286299 f, 0 .12241 f, 0 .0964093 f,
-0 .0150897 f, 0 .0532171 f, 0 .0625916 f, 0 .116939 f, 0 .118024 f,
0 .161918 f, -0 .00909767 f, 0 .100897 f, -0 .054563 f, -0 .175179 f,
-0 .0687892 f, 0 .00734235 f, 0 .109833 f, -0 .113776 f, 0 .0595405 f,
-0 .170255 f, 0 .0124815 f, -0 .0363301 f, -0 .0127038 f, 0 .0445554 f,
-0 .0729894 f, 0 .107428 f, -0 .0341417 f, 0 .132619 f, 0 .00984557 f,
-0 .00443654 f, 0 .202929 f, 0 .0945134 f, 0 .0148725 f, 0 .00998574 f,
-0 .0226449 f, 0 .0478197 f, -0 .0793442 f, 0 .0707599 f, -0 .084225 f,
0 .0865795 f, 0 .071104 f, -0 .047894 f, 0 .0838322 f, 0 .0635493 f,
-0 .00370265 f, -0 .157247 f, -0 .0289622 f, -0 .0590963 f, 0 .13207 f,
0 .00468011 f, -0 .0345372 f, 0 .217939 f, 0 .18861 f, -0 .0290393 f,
-0 .0440664 f, 0 .0126197 f, -0 .129132 f, -0 .124943 f, 0 .0968156 f,
-0 .0853643 f, -0 .182305 f, 0 .00461618 f, -0 .147095 f, -0 .230282 f,
0 .00856019 f, 0 .0278893 f, -0 .0300229 f, 0 .0417871 f, 0 .0804717 f,
-0 .0768571 f, -0 .0397085 f, -0 .0601096 f, 0 .100901 f, -0 .0184926 f,
0 .0350673 f, 0 .0971094 f, -0 .0171837 f, -0 .289644 f, -0 .0899041 f,
0 .08998 f, -0 .160319 f, -0 .0195103 f, 0 .0392167 f, -0 .137864 f,
-0 .0136294 f, 0 .0330886 f, -0 .0409244 f, -0 .092533 f, -0 .0427934 f,
-0 .191144 f, -0 .0969461 f, 0 .112035 f, 0 .138611 f, 0 .128717 f,
0 .191184 f, 0 .197462 f
};
float bias[] = { 0 .186703 f, 0 .204358 f, -0 .0230452 f };
float bn_gamma[] = { 1 .32173 f, 1 .26171 f, 1 .21966 f };
float bn_beta[] = { -0 .232595 f, -0 .222652 f, -0 .232209 f };
float bn_mean[] = { 0 .329233 f, 0 .199894 f, 0 .12389 f };
float bn_std[] = { 0 .311986 f, 0 .189737 f, 0 .247104 f };
CNN_BATCHNORM_PARAMS bn_params = {
bn_gamma,
bn_beta,
bn_mean,
bn_std,
};
CNN_CONFIG cnn_config = {
1 ,
0 ,
0 ,
0 ,
0 ,
{
{
1 ,
filter_width,
filter_height,
3 ,
7 ,
7 ,
0 ,
kernel,
bias,
PADDING_VALID,
RELU,
0 ,
0 ,
BRANCH_NO_COPY,
BRANCH_NOC,
{},
bn_params,
0 ,
},
},
};
CNN_THREAD_DATA thread_data = { 1 , nullptr };
RunCNNTest(image_width, image_height, input, expected, &cnn_config,
image_width, &thread_data, MSE_FLOAT_TOL);
}
TEST_F(CNNTest, TestMultithreading) {
  // A single 3x3 convolution layer (zero "same" padding, no activation) that
  // maps a 2x2 single-channel image to four output channels. The layer is
  // evaluated once without worker threads and once split across four
  // workers; both runs are checked against the same expected output.
  const int image_width = 2;
  const int image_height = 2;
  const int filter_width = 3;
  const int filter_height = 3;

  float input[] = { -2, 4, 1, 0 };
  float weights[] = {
    -4, 2, -2, 0,  -4, 4, -3, -3, -3, -1, 1,  0,  -5, -3, 0, -5, 0, 0,
    -1, 0, 2,  -5, 0,  1, 4,  2,  1,  0,  -2, -1, -5, -3, 2, -2, 1, -5,
  };
  float bias[] = { -4, -3, -2, 3 };
  float expected[] = { 2,  10, -8, -17, -24, 5,  -15, 6,
                       -5, -5, 7,  -10, 4,   13, 9,   -14 };

  CNN_CONFIG cnn_config = {
    1,  // num_layers
    0,  // is_residue
    0,  // ext_width
    0,  // ext_height
    0,  // strict_bounds
    {
        // layer_config
        {
            1,                  // in_channels
            filter_width,       // filter_width
            filter_height,      // filter_height
            4,                  // out_channels
            1,                  // skip_width
            1,                  // skip_height
            0,                  // max_pool
            weights,            // weights
            bias,               // bias
            PADDING_SAME_ZERO,  // pad
            NONE,               // activation
            0,                  // deconvolve
            0,                  // branch
            BRANCH_NO_COPY,     // branch_copy_type
            BRANCH_NOC,         // branch_combine_type
            {},                 // branch_config
            {},                 // bn_params
            0,                  // output_num
        },
    },
  };

  // Pass 1: no worker threads.
  CNN_THREAD_DATA thread_data = { 1, nullptr };
  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_FLOAT_TOL);

  // Pass 2: the same layer distributed over a pool of workers.
  const AVxWorkerInterface *const winterface = aom_get_worker_interface();
  const int kNumWorkers = 4;
  AVxWorker workers[kNumWorkers];
  for (AVxWorker &worker : workers) winterface->init(&worker);

  thread_data = { kNumWorkers, workers };
  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
             image_width, &thread_data, MSE_FLOAT_TOL);

  for (AVxWorker &worker : workers) winterface->end(&worker);
}
TEST_F(CNNTest, TestMultiOutput) {
  // Four-layer, two-branch network in which every layer writes its result to
  // a separate output buffer (output_num 0..3):
  //   layer 0: 3 -> num_filters channels, 2x2 kernel, stride 2 (8x8 -> 4x4);
  //            its output is also copied to another branch (BRANCH_OUTPUT,
  //            branch_config { 2, 0, 0 }), which layer 3 consumes.
  //   layer 1: 2x2 kernel, stride 2, RELU (4x4 -> 2x2), branch 0.
  //   layer 2: 2x2 kernel, stride 2, RELU (2x2 -> 1x1), branch 0.
  //   layer 3: 4x4 VALID kernel, stride 4, RELU on branch 1 (4x4 -> 1x1),
  //            concatenated (BRANCH_CAT) with branch 0's result.
  const int image_dim = 8;
  const int image_ch = 3;
  const int filter_dim = 2;
  const int stride = 2;
  const int num_filters = 2;

  const float input_[] = {
    1.7537929121f,     0.134331551012f,    0.123580039877f,   0.957731845246f,
    0.391006834217f,   1.00699352042f,     -0.778177955829f,  -0.814166433059f,
    -0.656374394915f,  0.321967305228f,    -2.19455719176f,   0.708035038966f,
    0.409148822266f,   -0.318254408902f,   0.152450211189f,   -0.250210793369f,
    0.826811563186f,   1.6804156584f,      0.273626975978f,   0.437936241887f,
    -0.329935520167f,  -0.288761611645f,   0.156937008304f,   0.271054157295f,
    -0.0224828854332f, 1.70110336895f,     -0.989066699309f,  1.30863131729f,
    -0.165813705702f,  0.00380178619265f,  -0.0837342367587f, 0.760954783156f,
    -0.413610373524f,  1.17968204175f,     0.720295719536f,   0.308718974472f,
    -1.10091337671f,   0.693160033687f,    -0.0202862320697f, 1.0221927503f,
    -1.24521801881f,   -0.478501952308f,   -1.71648619442f,   -0.182571723636f,
    0.339292649504f,   2.0806519131f,      0.967974033444f,   0.175248672328f,
    0.0658124561472f,  0.795504169496f,    0.750592557361f,   -1.46631013249f,
    -1.79052846838f,   -1.03672179515f,    -0.841985521653f,  1.20995011489f,
    0.140859718215f,   -0.651552622661f,   0.451065110806f,   1.1189443693f,
    0.100213260593f,   -0.834076868118f,   -1.28734321611f,   1.22064420095f,
    -0.364143084361f,  0.750961509335f,    -0.888689074553f,  -0.8253547106f,
    -1.21800999027f,   -0.966670603566f,   1.37384014741f,    0.47281264834f,
    -0.420416235531f,  0.520163906493f,    0.501296589423f,   1.53418976951f,
    0.715234751485f,   0.644551588907f,    0.0763504863375f,  -0.0018541943723f,
    0.322853189656f,   -0.795099723224f,   -0.125177096675f,  1.4476577471f,
    -0.585888410088f,  -1.44391754955f,    -0.610543221933f,  -0.221859179799f,
    0.252060200774f,   -0.86287169623f,    -0.0350246229157f, 1.0932311997f,
    0.899464648842f,   -0.468806951704f,   -0.300861137168f,  1.15776414206f,
    1.03268544738f,    -0.171579585622f,   -0.179136557119f,  -0.354091003368f,
    -0.612298249394f,  -1.20237379258f,    1.54604109659f,    0.130664370287f,
    0.885225111868f,   1.0362799581f,      0.980561720868f,   -0.619379186999f,
    -1.33818929924f,   -0.237233737961f,   -1.89335425073f,   0.567821011321f,
    0.862420368465f,   -1.37380916821f,    0.352190056666f,   0.611261516274f,
    0.393237747152f,   0.894686247967f,    0.190405182149f,   0.264872662911f,
    -0.0657009133797f, 0.0580512653493f,   -0.401825294366f,  0.4106081318f,
    0.49484512188f,    -0.0751103149442f,  -1.43243736382f,   1.79855656009f,
    -1.1075351975f,    0.000354882733011f, -0.950716438608f,  1.27129831688f,
    1.00495189838f,    0.110358656713f,    1.08315032822f,    -0.972676676218f,
    -0.0757668962831f, 1.88932045165f,     -0.0672638136275f, 0.425913010161f,
    -0.781540372017f,  0.976000248609f,    0.687218504122f,   1.31374513445f,
    -0.932658930672f,  -1.25339468479f,    0.422071294078f,   -0.24189927912f,
    0.216906604642f,   -1.88720997548f,    1.99252872889f,    0.353943735777f,
    0.737434784132f,   -1.17848645017f,    1.70424254896f,    0.775297112968f,
    -0.516392797501f,  0.398130609129f,    0.737248101457f,   0.166282500886f,
    1.24699015468f,    0.47116183125f,     1.19091180182f,    -0.372695424578f,
    0.219773209389f,   -0.829467838962f,   -0.52533122724f,   1.98707754595f,
    0.553692606972f,   -0.933228902369f,   1.55427751643f,    -1.08813399144f,
    -0.325686682094f,  0.205091443796f,    -1.70381666435f,   0.466465327942f,
    1.73126863447f,    -0.939133672634f,   1.48318077459f,    -0.599414038168f,
    -1.1583078687f,    0.518116190201f,    0.133571482458f,   0.84958342672f,
    1.02205000597f,    -0.0772082009087f,  -1.69567503859f,   1.4697939436f,
    1.67813743122f,    -0.627911582938f,   0.131380509137f,   -1.35717850726f,
  };
  // Planar input: one image_dim x image_dim plane per channel.
  const float *input[3] = { input_, &input_[image_dim * image_dim],
                            &input_[2 * image_dim * image_dim] };

  const float bias[] = { 0.0f, 0.0f };
  const float weights_1[] = {
    -0.489547413618f, 0.141916424749f,  -0.279286485585f,  -0.115322211094f,
    0.299572786936f,  0.205289980785f,  -0.536254480088f,  -0.253626313744f,
    -0.422883815849f, -0.169702966298f, -0.540104704793f,  0.495319646763f,
    0.298799079422f,  -0.10054550901f,  -0.306085047056f,  0.171061886165f,
    -0.108058703878f, -0.410734629888f, -0.0640674673049f, -0.386524840979f,
    -0.157203423678f, -0.362138920529f, -0.216206085209f,  0.147502517971f,
  };
  const float weights_2[] = {
    0.207580604357f,  0.480821146263f,  -0.29111909562f,   0.47422567493f,
    0.206892553253f,  -0.235067084092f, 0.354516800602f,   -0.212399370252f,
    -0.419071343731f, -0.050350731631f, -0.0516457320279f, -0.0359310500731f,
    0.567044864811f,  -0.060341127522f, 0.0501464839637f,  -0.437785677916f,
  };
  const float weights_3[] = {
    -0.0690452401448f, -0.356657338763f, -0.219464031809f, 0.551288365843f,
    0.181372090853f,   -0.00245268542109f, 0.409000696276f, -0.593209108763f,
    0.587352566749f,   -0.243720660227f, 0.266232713887f,  -0.00439285245097f,
    0.252883228305f,   0.152646192631f,  0.0918944932026f, 0.398853715057f,
  };
  const float weights_4[] = {
    0.207560791573f,   0.194201350401f,   0.227802322443f,   0.206533663345f,
    0.0557331066805f,  0.0224159800424f,  -0.143939197467f,  -0.27703361602f,
    0.130643888389f,   -0.269456557461f,  0.186242862864f,   -0.162879944774f,
    -0.145503996718f,  -0.0768822987581f, -0.203127976359f,  -0.238119922873f,
    -0.258806479994f,  0.0357957680385f,  -0.1027606976f,    -0.287920082345f,
    0.189047820993f,   0.250711538481f,   -0.272815714175f,  -0.0431449742024f,
    0.207261230996f,   -0.0396472677451f, 0.131236557412f,   0.174291832499f,
    -0.251515885765f,  -0.107164007499f,  0.185824534748f,   -0.00561585838161f,
    0.273393799578f,   -0.139563699075f,  -0.263922456031f,  -0.118859844081f,
    0.109230982597f,   -0.170170294794f,  0.0123025648515f,  -0.0839368964355f,
    -0.0774058234297f, 0.255847138286f,   -0.208430879637f,  0.279170114319f,
    -0.272890330712f,  -0.217725903006f,  -0.295923275459f,  -0.17008723953f,
    -0.284281803405f,  0.281406323629f,   0.266910044663f,   -0.209963914338f,
    0.271980962964f,   0.142013581699f,   -0.143896509026f,  -0.290509242975f,
    -0.305768180935f,  0.196902832117f,   -0.090424189662f,  -0.147460802346f,
    0.217722016651f,   0.12353848977f,    -0.169177363577f,  -0.0454230918512f,
  };

  const float expected_0[] = {
    -2.04858441055f,  -2.12883075791f,    -0.045177363807f, 0.763949675768f,
    -0.544361512821f, -1.58123168032f,    1.89319847039f,   0.16859080901f,
    -1.16023321135f,  -0.396988107751f,   1.76637090744f,   -1.40434786514f,
    0.908227575669f,  0.817064817605f,    0.215631134908f,  -0.848605613428f,
    -0.106756747018f, 0.0193027166685f,   0.801345615113f,  -0.395407237598f,
    -1.79983795658f,  -1.73054496242f,    0.0584392594454f, -0.388786095569f,
    -0.237269619354f, 0.000843578271263f, -1.24043512104f,  0.487839445893f,
    -0.394259726605f, 0.559632843424f,    -0.527224052291f, -1.53792340282f,
  };
  const float expected_1[] = {
    0.0f, 0.0f, 0.0f, 0.0f, 0.4057888292f, 0.325309571755f, 0.0f, 1.22013465602f,
  };
  const float expected_2[] = {
    0.156119444687f,
    0.517385299817f,
  };
  // Layer 3's own two channels followed by branch 0's (expected_2) channels.
  const float expected_3[] = {
    0.224177852984f,
    0.503384419034f,
    0.156119444687f,
    0.517385299817f,
  };
  const float *expected[] = { expected_0, expected_1, expected_2, expected_3 };

  CNN_CONFIG cnn_config = {
    4,  // num_layers
    0,  // is_residue
    0,  // ext_width
    0,  // ext_height
    0,  // strict_bounds
    {
        // layer_config
        {
            image_ch,           // in_channels
            filter_dim,         // filter_width
            filter_dim,         // filter_height
            num_filters,        // out_channels
            stride,             // skip_width
            stride,             // skip_height
            0,                  // max_pool
            weights_1,          // weights
            bias,               // bias
            PADDING_SAME_ZERO,  // pad
            NONE,               // activation
            0,                  // deconvolve
            0,                  // branch
            BRANCH_OUTPUT,      // branch_copy_type
            BRANCH_NOC,         // branch_combine_type
            { 2, 0, 0 },        // branch_config
            {},                 // bn_params
            0,                  // output_num
        },
        {
            num_filters,        // in_channels
            filter_dim,         // filter_width
            filter_dim,         // filter_height
            num_filters,        // out_channels
            stride,             // skip_width
            stride,             // skip_height
            0,                  // max_pool
            weights_2,          // weights
            bias,               // bias
            PADDING_SAME_ZERO,  // pad
            RELU,               // activation
            0,                  // deconvolve
            0,                  // branch
            BRANCH_NO_COPY,     // branch_copy_type
            BRANCH_NOC,         // branch_combine_type
            {},                 // branch_config
            {},                 // bn_params
            1,                  // output_num
        },
        {
            num_filters,        // in_channels
            filter_dim,         // filter_width
            filter_dim,         // filter_height
            num_filters,        // out_channels
            stride,             // skip_width
            stride,             // skip_height
            0,                  // max_pool
            weights_3,          // weights
            bias,               // bias
            PADDING_SAME_ZERO,  // pad
            RELU,               // activation
            0,                  // deconvolve
            0,                  // branch
            BRANCH_NO_COPY,     // branch_copy_type
            BRANCH_NOC,         // branch_combine_type
            {},                 // branch_config
            {},                 // bn_params
            2,                  // output_num
        },
        {
            num_filters,     // in_channels
            2 * filter_dim,  // filter_width
            2 * filter_dim,  // filter_height
            num_filters,     // out_channels
            2 * stride,      // skip_width
            2 * stride,      // skip_height
            0,               // max_pool
            weights_4,       // weights
            bias,            // bias
            PADDING_VALID,   // pad
            RELU,            // activation
            0,               // deconvolve
            1,               // branch
            BRANCH_NO_COPY,  // branch_copy_type
            BRANCH_CAT,      // branch_combine_type
            { 0, 0, 1 },     // branch_config
            {},              // bn_params
            3,               // output_num
        },
    },
  };

  CNN_THREAD_DATA thread_data = { 1, nullptr };
  const int num_outputs = 4;
  // Channel count of each output. These are derived from num_filters (the
  // per-layer channel count), not from filter_dim (the spatial kernel size).
  // The last output concatenates layer 3's channels with branch 0's, so it
  // carries twice as many.
  const int output_chs[4] = { num_filters, num_filters, num_filters,
                              2 * num_filters };
  // Spatial side length of each output: 8 -> 4 -> 2 -> 1 through the
  // stride-2 layers. Layer 3 applies a 4x4 VALID kernel with skip 4 to a 4x4
  // input, giving 1x1.
  const int output_dims[4] = { 4, 2, 1, 1 };

  int total_size = 0;
  int total_chs = 0;
  for (int output_idx = 0; output_idx < num_outputs; ++output_idx) {
    total_size += output_chs[output_idx] * output_dims[output_idx] *
                  output_dims[output_idx];
    total_chs += output_chs[output_idx];
  }
  // Every channel of every output needs a slot in output[] below.
  ASSERT_LE(total_chs, CNN_MAX_CHANNELS);

  float *const output_ =
      (float *)aom_malloc(sizeof(*output_) * total_size);
  ASSERT_NE(output_, nullptr);

  // Point one entry of output[] at each channel plane, in output order.
  float *output[CNN_MAX_CHANNELS] = { nullptr };
  int ch_ite = 0;
  float *output_ite = output_;
  for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
    for (int channel = 0; channel < output_chs[output_idx]; ++channel) {
      output[ch_ite++] = output_ite;
      output_ite += output_dims[output_idx] * output_dims[output_idx];
    }
  }

  CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_dims,
                                  output };
  RunMultiOutCNNTest(input, image_dim, image_dim, image_dim, &cnn_config,
                     &thread_data, &output_struct, expected, MSE_FLOAT_TOL);
  aom_free(output_);
}
namespace {

typedef void (*CNNConvolveNoMaxpoolPaddingValidFunc)(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride,
    int start_idx, int cstep, int channel_step);

typedef libaom_test::FuncParam<CNNConvolveNoMaxpoolPaddingValidFunc>
    CNNConvolveTestFuncs;

// Checks an optimized no-maxpool / valid-padding convolution (tst_func)
// against its C reference (ref_func). Each layer of
// av1_intra_mode_cnn_partition_cnn_config is run on random input of the size
// that layer would receive.
class CNNConvolveTest : public ::testing::TestWithParam<CNNConvolveTestFuncs> {
 protected:
  void SetUp() override { params_ = GetParam(); }

  // Walks the partition CNN layer by layer, starting from a 65x65 input. Each
  // layer's output size becomes the next layer's input size. run_times == 1
  // compares outputs; larger values time both functions instead.
  void RunCNNConvolveSetup(int run_times) {
    int in_width = 65;
    int in_height = 65;
    const CNN_CONFIG *cnn_config = &av1_intra_mode_cnn_partition_cnn_config;

    for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
      const CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
      const int in_channels = layer_config->in_channels;
      const int out_channels = layer_config->out_channels;
      // The per-channel pointer tables below hold CNN_MAX_CHANNELS entries.
      ASSERT_LE(in_channels, CNN_MAX_CHANNELS);
      ASSERT_LE(out_channels, CNN_MAX_CHANNELS);

      int out_width = 0, out_height = 0;
      const int in_size = in_width * in_height;
      // Width first, matching the (&out_width, &out_height) output order.
      av1_find_cnn_layer_output_size(in_width, in_height, layer_config,
                                     &out_width, &out_height);
      const int out_size = out_width * out_height;

      float *input[CNN_MAX_CHANNELS];
      float *output_ref[CNN_MAX_CHANNELS];
      float *output_mod[CNN_MAX_CHANNELS];

      float *input_data =
          (float *)aom_malloc(sizeof(*input_data) * in_size * in_channels);
      float *out_data_ref = (float *)aom_calloc(sizeof(*out_data_ref),
                                                out_size * out_channels);
      float *out_data_mod = (float *)aom_calloc(sizeof(*out_data_mod),
                                                out_size * out_channels);
      if (input_data == nullptr || out_data_ref == nullptr ||
          out_data_mod == nullptr) {
        // Release whichever buffers did get allocated (aom_free accepts
        // nullptr) before reporting the failure.
        aom_free(input_data);
        aom_free(out_data_ref);
        aom_free(out_data_mod);
        FAIL() << "Memory allocation failed for layer " << layer;
      }

      // Fill each input plane with pseudo-random values in [-0.5, 0.5).
      float *temp_ptr = input_data;
      for (int i = 0; i < in_channels; ++i) {
        input[i] = temp_ptr;
        for (int j = 0; j < in_size; j++) {
          *(temp_ptr++) = ((float)rng_.Rand31() - (1 << 30)) / (1u << 31);
        }
      }

      float *temp_ptr1 = out_data_ref;
      float *temp_ptr2 = out_data_mod;
      for (int i = 0; i < out_channels; ++i) {
        output_ref[i] = temp_ptr1;
        output_mod[i] = temp_ptr2;
        temp_ptr1 += out_size;
        temp_ptr2 += out_size;
      }

      RunCNNConvolveTest(input, in_width, in_height, out_size, layer_config,
                         0, 1, run_times, layer, output_ref, output_mod,
                         out_width);

      // This layer's output size is the next layer's input size.
      in_width = out_width;
      in_height = out_height;
      aom_free(input_data);
      aom_free(out_data_ref);
      aom_free(out_data_mod);
    }
  }

  // Runs ref_func and tst_func run_times times each on the same input. With
  // run_times > 1 it prints both elapsed times and their ratio. Otherwise it
  // compares every output value:
  //  - if the reference magnitude is below the tolerance, the absolute error
  //    must be within the tolerance;
  //  - otherwise, the relative error must be within the tolerance.
  // Both positive and negative reference values are checked.
  void RunCNNConvolveTest(float **input, int in_width, int in_height,
                          int out_size, const CNN_LAYER_CONFIG *layer_config,
                          int start_idx, int step, int run_times, int layer,
                          float **output_ref, float **output_mod,
                          int out_stride) {
    const int cstep = layer_config->in_channels * layer_config->out_channels;
    const int channel_step = AOMMAX(step, 1);
    aom_usec_timer timer;
    aom_usec_timer_start(&timer);
    for (int i = 0; i < run_times; ++i) {
      params_.ref_func((const float **)input, in_width, in_height, in_width,
                       layer_config, output_ref, out_stride, start_idx, cstep,
                       channel_step);
    }
    aom_usec_timer_mark(&timer);
    const double time1 = static_cast<double>(aom_usec_timer_elapsed(&timer));

    aom_usec_timer_start(&timer);
    for (int i = 0; i < run_times; ++i) {
      params_.tst_func((const float **)input, in_width, in_height, in_width,
                       layer_config, output_mod, out_stride, start_idx, cstep,
                       channel_step);
    }
    aom_usec_timer_mark(&timer);
    const double time2 = static_cast<double>(aom_usec_timer_elapsed(&timer));

    if (run_times > 1) {
      printf("layer : %d \n", layer);
      // aom_usec_timer measures microseconds.
      printf("%7.2f/%7.2fus (%3.2f)\n", time1, time2, time1 / time2);
      return;
    }

    for (int channel = 0; channel < layer_config->out_channels; ++channel) {
      const float *buf_ref = output_ref[channel];
      const float *buf_mod = output_mod[channel];
      for (int i = 0; i < out_size; ++i) {
        const float abs_error = fabsf(buf_ref[i] - buf_mod[i]);
        if (fabsf(buf_ref[i]) < CNN_CONVOLVE_PIXELWISE_FLOAT_TOL) {
          // Relative error is unstable near zero, so use an absolute bound.
          ASSERT_LE(abs_error, CNN_CONVOLVE_PIXELWISE_FLOAT_TOL)
              << "Reference output was near-zero, test output was not ("
              << buf_mod[i] << ")"
              << " channel " << channel << " pixel " << i;
        } else {
          const float relative_error = abs_error / fabsf(buf_ref[i]);
          ASSERT_LE(relative_error, CNN_CONVOLVE_PIXELWISE_FLOAT_TOL)
              << " channel " << channel << " pixel " << i << ": "
              << buf_ref[i] << "/" << buf_mod[i] << std::endl;
        }
      }
    }
  }

 private:
  CNNConvolveTestFuncs params_;
  libaom_test::ACMRandom rng_;
};

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CNNConvolveTest);

TEST_P(CNNConvolveTest, CheckOutput) { RunCNNConvolveSetup(1); }

TEST_P(CNNConvolveTest, DISABLED_Speed) { RunCNNConvolveSetup(100000); }

#if HAVE_AVX2 && !CONFIG_EXCLUDE_SIMD_MISMATCH
INSTANTIATE_TEST_SUITE_P(AVX2, CNNConvolveTest,
                         ::testing::Values(CNNConvolveTestFuncs(
                             &av1_cnn_convolve_no_maxpool_padding_valid_c,
                             &av1_cnn_convolve_no_maxpool_padding_valid_avx2)));
#endif

#if HAVE_NEON
INSTANTIATE_TEST_SUITE_P(NEON, CNNConvolveTest,
                         ::testing::Values(CNNConvolveTestFuncs(
                             &av1_cnn_convolve_no_maxpool_padding_valid_c,
                             &av1_cnn_convolve_no_maxpool_padding_valid_neon)));
#endif

}  // namespace
Messung V0.5 in Prozent C=96 H=85 G=90