/*
* Copyright (c) 2018, Alliance for Open Media. All rights reserved.
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <float .h>
#include <string.h>
#include "tools/txfm_analyzer/txfm_graph.h"
typedef enum CODE_TYPE {
CODE_TYPE_C,
CODE_TYPE_SSE2,
CODE_TYPE_SSE4_1
} CODE_TYPE;
int get_cos_idx(double value, int mod) {
return round(acos(fabs(value)) / PI * mod);
}
char *cos_text_arr(double value, int mod, char *text, int size) {
int num = get_cos_idx(value, mod);
if (value < 0 ) {
snprintf(text, size, "-cospi[%2d]" , num);
} else {
snprintf(text, size, " cospi[%2d]" , num);
}
if (num == 0 )
printf("v: %f -> %d/%d v==-1 is %d\n" , value, num, mod, value == -1 );
return text;
}
char *cos_text_sse2(double w0, double w1, int mod, char *text, int size) {
int idx0 = get_cos_idx(w0, mod);
int idx1 = get_cos_idx(w1, mod);
char p[] = "p" ;
char n[] = "m" ;
char *sgn0 = w0 < 0 ? n : p;
char *sgn1 = w1 < 0 ? n : p;
snprintf(text, size, "cospi_%s%02d_%s%02d" , sgn0, idx0, sgn1, idx1);
return text;
}
char *cos_text_sse4_1(double w, int mod, char *text, int size) {
int idx = get_cos_idx(w, mod);
char p[] = "p" ;
char n[] = "m" ;
char *sgn = w < 0 ? n : p;
snprintf(text, size, "cospi_%s%02d" , sgn, idx);
return text;
}
void node_to_code_c(Node *node, const char *buf0, const char *buf1) {
int cnt = 0 ;
for (int i = 0 ; i < 2 ; i++) {
if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0 ) cnt++;
}
if (cnt == 2 ) {
int cnt2 = 0 ;
printf(" %s[%d] =" , buf1, node->nodeIdx);
for (int i = 0 ; i < 2 ; i++) {
if (fabs(node->inWeight[i]) == 1 ) {
cnt2++;
}
}
if (cnt2 == 2 ) {
printf(" apply_value(" );
}
int cnt1 = 0 ;
for (int i = 0 ; i < 2 ; i++) {
if (node->inWeight[i] == 1 ) {
if (cnt1 > 0 )
printf(" + %s[%d]" , buf0, node->inNodeIdx[i]);
else
printf(" %s[%d]" , buf0, node->inNodeIdx[i]);
cnt1++;
} else if (node->inWeight[i] == -1 ) {
if (cnt1 > 0 )
printf(" - %s[%d]" , buf0, node->inNodeIdx[i]);
else
printf("-%s[%d]" , buf0, node->inNodeIdx[i]);
cnt1++;
}
}
if (cnt2 == 2 ) {
printf(", stage_range[stage])" );
}
printf(";\n" );
} else {
char w0[100 ];
char w1[100 ];
printf(
" %s[%d] = half_btf(%s, %s[%d], %s, %s[%d], "
"cos_bit);\n" ,
buf1, node->nodeIdx, cos_text_arr(node->inWeight[0 ], COS_MOD, w0, 100 ),
buf0, node->inNodeIdx[0 ],
cos_text_arr(node->inWeight[1 ], COS_MOD, w1, 100 ), buf0,
node->inNodeIdx[1 ]);
}
}
void gen_code_c(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
char *fun_name = new char [100 ];
get_fun_name(fun_name, 100 , type, node_num);
printf("\n" );
printf(
"void av1_%s(const int32_t *input, int32_t *output, int8_t cos_bit, "
"const int8_t* stage_range) "
"{\n" ,
fun_name);
printf(" assert(output != input);\n" );
printf(" const int32_t size = %d;\n" , node_num);
printf(" const int32_t *cospi = cospi_arr(cos_bit);\n" );
printf("\n" );
printf(" int32_t stage = 0;\n" );
printf(" int32_t *bf0, *bf1;\n" );
printf(" int32_t step[%d];\n" , node_num);
const char *buf0 = "bf0" ;
const char *buf1 = "bf1" ;
const char *input = "input" ;
int si = 0 ;
printf("\n" );
printf(" // stage %d;\n", si);
printf(" apply_range(stage, input, %s, size, stage_range[stage]);\n" , input);
si = 1 ;
printf("\n" );
printf(" // stage %d;\n", si);
printf(" stage++;\n" );
if (si % 2 == (stage_num - 1 ) % 2 ) {
printf(" %s = output;\n" , buf1);
} else {
printf(" %s = step;\n" , buf1);
}
for (int ni = 0 ; ni < node_num; ni++) {
int idx = get_idx(si, ni, node_num);
node_to_code_c(node + idx, input, buf1);
}
printf(" range_check_buf(stage, input, bf1, size, stage_range[stage]);\n" );
for (int si = 2 ; si < stage_num; si++) {
printf("\n" );
printf(" // stage %d\n", si);
printf(" stage++;\n" );
if (si % 2 == (stage_num - 1 ) % 2 ) {
printf(" %s = step;\n" , buf0);
printf(" %s = output;\n" , buf1);
} else {
printf(" %s = output;\n" , buf0);
printf(" %s = step;\n" , buf1);
}
// computation code
for (int ni = 0 ; ni < node_num; ni++) {
int idx = get_idx(si, ni, node_num);
node_to_code_c(node + idx, buf0, buf1);
}
if (si != stage_num - 1 ) {
printf(
" range_check_buf(stage, input, bf1, size, stage_range[stage]);\n" );
}
}
printf(" apply_range(stage, input, output, size, stage_range[stage]);\n" );
printf("}\n" );
}
void single_node_to_code_sse2(Node *node, const char *buf0, const char *buf1) {
printf(" %s[%2d] =" , buf1, node->nodeIdx);
if (node->inWeight[0 ] == 1 && node->inWeight[1 ] == 1 ) {
printf(" _mm_adds_epi16(%s[%d], %s[%d])" , buf0, node->inNodeIdx[0 ], buf0,
node->inNodeIdx[1 ]);
} else if (node->inWeight[0 ] == 1 && node->inWeight[1 ] == -1 ) {
printf(" _mm_subs_epi16(%s[%d], %s[%d])" , buf0, node->inNodeIdx[0 ], buf0,
node->inNodeIdx[1 ]);
} else if (node->inWeight[0 ] == -1 && node->inWeight[1 ] == 1 ) {
printf(" _mm_subs_epi16(%s[%d], %s[%d])" , buf0, node->inNodeIdx[1 ], buf0,
node->inNodeIdx[0 ]);
} else if (node->inWeight[0 ] == 1 && node->inWeight[1 ] == 0 ) {
printf(" %s[%d]" , buf0, node->inNodeIdx[0 ]);
} else if (node->inWeight[0 ] == 0 && node->inWeight[1 ] == 1 ) {
printf(" %s[%d]" , buf0, node->inNodeIdx[1 ]);
} else if (node->inWeight[0 ] == -1 && node->inWeight[1 ] == 0 ) {
printf(" _mm_subs_epi16(__zero, %s[%d])" , buf0, node->inNodeIdx[0 ]);
} else if (node->inWeight[0 ] == 0 && node->inWeight[1 ] == -1 ) {
printf(" _mm_subs_epi16(__zero, %s[%d])" , buf0, node->inNodeIdx[1 ]);
}
printf(";\n" );
}
void pair_node_to_code_sse2(Node *node, Node *partnerNode, const char *buf0,
const char *buf1) {
char temp0[100 ];
char temp1[100 ];
// btf_16_sse2_type0(w0, w1, in0, in1, out0, out1)
if (node->inNodeIdx[0 ] != partnerNode->inNodeIdx[0 ])
printf(" btf_16_sse2(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d]);\n" ,
cos_text_sse2(node->inWeight[0 ], node->inWeight[1 ], COS_MOD, temp0,
100 ),
cos_text_sse2(partnerNode->inWeight[1 ], partnerNode->inWeight[0 ],
COS_MOD, temp1, 100 ),
buf0, node->inNodeIdx[0 ], buf0, node->inNodeIdx[1 ], buf1,
node->nodeIdx, buf1, partnerNode->nodeIdx);
else
printf(" btf_16_sse2(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d]);\n" ,
cos_text_sse2(node->inWeight[0 ], node->inWeight[1 ], COS_MOD, temp0,
100 ),
cos_text_sse2(partnerNode->inWeight[0 ], partnerNode->inWeight[1 ],
COS_MOD, temp1, 100 ),
buf0, node->inNodeIdx[0 ], buf0, node->inNodeIdx[1 ], buf1,
node->nodeIdx, buf1, partnerNode->nodeIdx);
}
Node *get_partner_node(Node *node) {
int diff = node->inNode[1 ]->nodeIdx - node->nodeIdx;
return node + diff;
}
void node_to_code_sse2(Node *node, const char *buf0, const char *buf1) {
int cnt = 0 ;
int cnt1 = 0 ;
if (node->visited == 0 ) {
node->visited = 1 ;
for (int i = 0 ; i < 2 ; i++) {
if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0 ) cnt++;
if (fabs(node->inWeight[i]) == 1 ) cnt1++;
}
if (cnt == 2 ) {
if (cnt1 == 2 ) {
// has a partner
Node *partnerNode = get_partner_node(node);
partnerNode->visited = 1 ;
single_node_to_code_sse2(node, buf0, buf1);
single_node_to_code_sse2(partnerNode, buf0, buf1);
} else {
single_node_to_code_sse2(node, buf0, buf1);
}
} else {
Node *partnerNode = get_partner_node(node);
partnerNode->visited = 1 ;
pair_node_to_code_sse2(node, partnerNode, buf0, buf1);
}
}
}
void gen_cospi_list_sse2(Node *node, int stage_num, int node_num) {
int visited[65 ][65 ][2 ][2 ];
memset(visited, 0 , sizeof (visited));
char text[100 ];
char text1[100 ];
char text2[100 ];
int size = 100 ;
printf("\n" );
for (int si = 1 ; si < stage_num; si++) {
for (int ni = 0 ; ni < node_num; ni++) {
int idx = get_idx(si, ni, node_num);
int cnt = 0 ;
Node *node0 = node + idx;
if (node0->visited == 0 ) {
node0->visited = 1 ;
for (int i = 0 ; i < 2 ; i++) {
if (fabs(node0->inWeight[i]) == 1 || fabs(node0->inWeight[i]) == 0 )
cnt++;
}
if (cnt != 2 ) {
{
double w0 = node0->inWeight[0 ];
double w1 = node0->inWeight[1 ];
int idx0 = get_cos_idx(w0, COS_MOD);
int idx1 = get_cos_idx(w1, COS_MOD);
int sgn0 = w0 < 0 ? 1 : 0 ;
int sgn1 = w1 < 0 ? 1 : 0 ;
if (!visited[idx0][idx1][sgn0][sgn1]) {
visited[idx0][idx1][sgn0][sgn1] = 1 ;
printf(" __m128i %s = pair_set_epi16(%s, %s);\n" ,
cos_text_sse2(w0, w1, COS_MOD, text, size),
cos_text_arr(w0, COS_MOD, text1, size),
cos_text_arr(w1, COS_MOD, text2, size));
}
}
Node *node1 = get_partner_node(node0);
node1->visited = 1 ;
if (node1->inNode[0 ]->nodeIdx != node0->inNode[0 ]->nodeIdx) {
double w0 = node1->inWeight[0 ];
double w1 = node1->inWeight[1 ];
int idx0 = get_cos_idx(w0, COS_MOD);
int idx1 = get_cos_idx(w1, COS_MOD);
int sgn0 = w0 < 0 ? 1 : 0 ;
int sgn1 = w1 < 0 ? 1 : 0 ;
if (!visited[idx1][idx0][sgn1][sgn0]) {
visited[idx1][idx0][sgn1][sgn0] = 1 ;
printf(" __m128i %s = pair_set_epi16(%s, %s);\n" ,
cos_text_sse2(w1, w0, COS_MOD, text, size),
cos_text_arr(w1, COS_MOD, text1, size),
cos_text_arr(w0, COS_MOD, text2, size));
}
} else {
double w0 = node1->inWeight[0 ];
double w1 = node1->inWeight[1 ];
int idx0 = get_cos_idx(w0, COS_MOD);
int idx1 = get_cos_idx(w1, COS_MOD);
int sgn0 = w0 < 0 ? 1 : 0 ;
int sgn1 = w1 < 0 ? 1 : 0 ;
if (!visited[idx0][idx1][sgn0][sgn1]) {
visited[idx0][idx1][sgn0][sgn1] = 1 ;
printf(" __m128i %s = pair_set_epi16(%s, %s);\n" ,
cos_text_sse2(w0, w1, COS_MOD, text, size),
cos_text_arr(w0, COS_MOD, text1, size),
cos_text_arr(w1, COS_MOD, text2, size));
}
}
}
}
}
}
}
void gen_code_sse2(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
char *fun_name = new char [100 ];
get_fun_name(fun_name, 100 , type, node_num);
printf("\n" );
printf(
"void %s_sse2(const __m128i *input, __m128i *output, int8_t cos_bit) "
"{\n" ,
fun_name);
printf(" const int32_t* cospi = cospi_arr(cos_bit);\n" );
printf(" const __m128i __zero = _mm_setzero_si128();\n" );
printf(" const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));\n" );
graph_reset_visited(node, stage_num, node_num);
gen_cospi_list_sse2(node, stage_num, node_num);
graph_reset_visited(node, stage_num, node_num);
for (int si = 1 ; si < stage_num; si++) {
char in[100 ];
char out[100 ];
printf("\n" );
printf(" // stage %d\n", si);
if (si == 1 )
snprintf(in, 100 , "%s" , "input" );
else
snprintf(in, 100 , "x%d" , si - 1 );
if (si == stage_num - 1 ) {
snprintf(out, 100 , "%s" , "output" );
} else {
snprintf(out, 100 , "x%d" , si);
printf(" __m128i %s[%d];\n" , out, node_num);
}
// computation code
for (int ni = 0 ; ni < node_num; ni++) {
int idx = get_idx(si, ni, node_num);
node_to_code_sse2(node + idx, in, out);
}
}
printf("}\n" );
}
void gen_cospi_list_sse4_1(Node *node, int stage_num, int node_num) {
int visited[65 ][2 ];
memset(visited, 0 , sizeof (visited));
char text[100 ];
char text1[100 ];
int size = 100 ;
printf("\n" );
for (int si = 1 ; si < stage_num; si++) {
for (int ni = 0 ; ni < node_num; ni++) {
int idx = get_idx(si, ni, node_num);
Node *node0 = node + idx;
if (node0->visited == 0 ) {
int cnt = 0 ;
node0->visited = 1 ;
for (int i = 0 ; i < 2 ; i++) {
if (fabs(node0->inWeight[i]) == 1 || fabs(node0->inWeight[i]) == 0 )
cnt++;
}
if (cnt != 2 ) {
for (int i = 0 ; i < 2 ; i++) {
if (fabs(node0->inWeight[i]) != 1 &&
fabs(node0->inWeight[i]) != 0 ) {
double w = node0->inWeight[i];
int idx = get_cos_idx(w, COS_MOD);
int sgn = w < 0 ? 1 : 0 ;
if (!visited[idx][sgn]) {
visited[idx][sgn] = 1 ;
printf(" __m128i %s = _mm_set1_epi32(%s);\n" ,
cos_text_sse4_1(w, COS_MOD, text, size),
cos_text_arr(w, COS_MOD, text1, size));
}
}
}
Node *node1 = get_partner_node(node0);
node1->visited = 1 ;
}
}
}
}
}
void single_node_to_code_sse4_1(Node *node, const char *buf0,
const char *buf1) {
printf(" %s[%2d] =" , buf1, node->nodeIdx);
if (node->inWeight[0 ] == 1 && node->inWeight[1 ] == 1 ) {
printf(" _mm_add_epi32(%s[%d], %s[%d])" , buf0, node->inNodeIdx[0 ], buf0,
node->inNodeIdx[1 ]);
} else if (node->inWeight[0 ] == 1 && node->inWeight[1 ] == -1 ) {
printf(" _mm_sub_epi32(%s[%d], %s[%d])" , buf0, node->inNodeIdx[0 ], buf0,
node->inNodeIdx[1 ]);
} else if (node->inWeight[0 ] == -1 && node->inWeight[1 ] == 1 ) {
printf(" _mm_sub_epi32(%s[%d], %s[%d])" , buf0, node->inNodeIdx[1 ], buf0,
node->inNodeIdx[0 ]);
} else if (node->inWeight[0 ] == 1 && node->inWeight[1 ] == 0 ) {
printf(" %s[%d]" , buf0, node->inNodeIdx[0 ]);
} else if (node->inWeight[0 ] == 0 && node->inWeight[1 ] == 1 ) {
printf(" %s[%d]" , buf0, node->inNodeIdx[1 ]);
} else if (node->inWeight[0 ] == -1 && node->inWeight[1 ] == 0 ) {
printf(" _mm_sub_epi32(__zero, %s[%d])" , buf0, node->inNodeIdx[0 ]);
} else if (node->inWeight[0 ] == 0 && node->inWeight[1 ] == -1 ) {
printf(" _mm_sub_epi32(__zero, %s[%d])" , buf0, node->inNodeIdx[1 ]);
}
printf(";\n" );
}
void pair_node_to_code_sse4_1(Node *node, Node *partnerNode, const char *buf0,
const char *buf1) {
char temp0[100 ];
char temp1[100 ];
if (node->inWeight[0 ] * partnerNode->inWeight[0 ] < 0 ) {
/* type0
* cos sin
* sin -cos
*/
// btf_32_sse2_type0(w0, w1, in0, in1, out0, out1)
// out0 = w0*in0 + w1*in1
// out1 = -w0*in1 + w1*in0
printf(
" btf_32_type0_sse4_1_new(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d], "
"__rounding, cos_bit);\n" ,
cos_text_sse4_1(node->inWeight[0 ], COS_MOD, temp0, 100 ),
cos_text_sse4_1(node->inWeight[1 ], COS_MOD, temp1, 100 ), buf0,
node->inNodeIdx[0 ], buf0, node->inNodeIdx[1 ], buf1, node->nodeIdx, buf1,
partnerNode->nodeIdx);
} else {
/* type1
* cos sin
* -sin cos
*/
// btf_32_sse2_type1(w0, w1, in0, in1, out0, out1)
// out0 = w0*in0 + w1*in1
// out1 = w0*in1 - w1*in0
printf(
" btf_32_type1_sse4_1_new(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d], "
"__rounding, cos_bit);\n" ,
cos_text_sse4_1(node->inWeight[0 ], COS_MOD, temp0, 100 ),
cos_text_sse4_1(node->inWeight[1 ], COS_MOD, temp1, 100 ), buf0,
node->inNodeIdx[0 ], buf0, node->inNodeIdx[1 ], buf1, node->nodeIdx, buf1,
partnerNode->nodeIdx);
}
}
void node_to_code_sse4_1(Node *node, const char *buf0, const char *buf1) {
int cnt = 0 ;
int cnt1 = 0 ;
if (node->visited == 0 ) {
node->visited = 1 ;
for (int i = 0 ; i < 2 ; i++) {
if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0 ) cnt++;
if (fabs(node->inWeight[i]) == 1 ) cnt1++;
}
if (cnt == 2 ) {
if (cnt1 == 2 ) {
// has a partner
Node *partnerNode = get_partner_node(node);
partnerNode->visited = 1 ;
single_node_to_code_sse4_1(node, buf0, buf1);
single_node_to_code_sse4_1(partnerNode, buf0, buf1);
} else {
single_node_to_code_sse2(node, buf0, buf1);
}
} else {
Node *partnerNode = get_partner_node(node);
partnerNode->visited = 1 ;
pair_node_to_code_sse4_1(node, partnerNode, buf0, buf1);
}
}
}
void gen_code_sse4_1(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
char *fun_name = new char [100 ];
get_fun_name(fun_name, 100 , type, node_num);
printf("\n" );
printf(
"void %s_sse4_1(const __m128i *input, __m128i *output, int8_t cos_bit) "
"{\n" ,
fun_name);
printf(" const int32_t* cospi = cospi_arr(cos_bit);\n" );
printf(" const __m128i __zero = _mm_setzero_si128();\n" );
printf(" const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));\n" );
graph_reset_visited(node, stage_num, node_num);
gen_cospi_list_sse4_1(node, stage_num, node_num);
graph_reset_visited(node, stage_num, node_num);
for (int si = 1 ; si < stage_num; si++) {
char in[100 ];
char out[100 ];
printf("\n" );
printf(" // stage %d\n", si);
if (si == 1 )
snprintf(in, 100 , "%s" , "input" );
else
snprintf(in, 100 , "x%d" , si - 1 );
if (si == stage_num - 1 ) {
snprintf(out, 100 , "%s" , "output" );
} else {
snprintf(out, 100 , "x%d" , si);
printf(" __m128i %s[%d];\n" , out, node_num);
}
// computation code
for (int ni = 0 ; ni < node_num; ni++) {
int idx = get_idx(si, ni, node_num);
node_to_code_sse4_1(node + idx, in, out);
}
}
printf("}\n" );
}
void gen_hybrid_code(CODE_TYPE code_type, TYPE_TXFM txfm_type, int node_num) {
int stage_num = get_hybrid_stage_num(txfm_type, node_num);
Node *node = new Node[node_num * stage_num];
init_graph(node, stage_num, node_num);
gen_hybrid_graph_1d(node, stage_num, node_num, 0 , 0 , node_num, txfm_type);
switch (code_type) {
case CODE_TYPE_C: gen_code_c(node, stage_num, node_num, txfm_type); break ;
case CODE_TYPE_SSE2:
gen_code_sse2(node, stage_num, node_num, txfm_type);
break ;
case CODE_TYPE_SSE4_1:
gen_code_sse4_1(node, stage_num, node_num, txfm_type);
break ;
}
delete [] node;
}
int main(int argc, char **argv) {
CODE_TYPE code_type = CODE_TYPE_SSE4_1;
for (int txfm_type = TYPE_DCT; txfm_type < TYPE_LAST; txfm_type++) {
for (int node_num = 4 ; node_num <= 64 ; node_num *= 2 ) {
gen_hybrid_code(code_type, (TYPE_TXFM)txfm_type, node_num);
}
}
return 0 ;
}
Messung V0.5 in Prozent C=100 H=87 G=93
¤ Dauer der Verarbeitung: 0.11 Sekunden
(vorverarbeitet am 2026-06-06)
¤
*© Formatika GbR, Deutschland