Commit 76ebc5ad authored by Raymond Knopp

addition of ARM NEON intrinsics

parent 88d5bf42
@@ -254,6 +254,18 @@ void build_decoder_tree(t_nrPolar_params *pp) {
}
#if defined(__arm__) || defined(__aarch64__)
// translate 1-1 SIMD functions from SSE to NEON
#define __m128i int16x8_t
#define __m64 int16x4_t
#define _mm_abs_epi16(a) vabsq_s16(a)
#define _mm_min_epi16(a,b) vminq_s16(a,b)
#define _mm_subs_epi16(a,b) vqsubq_s16(a,b)
#define _mm_abs_pi16(a) vabs_s16(a)
#define _mm_min_pi16(a,b) vmin_s16(a,b)
#define _mm_subs_pi16(a,b) vqsub_s16(a,b)
#endif
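
With the one-to-one macro mapping above, code written against the SSE intrinsic names compiles unchanged on ARM, where it resolves to the corresponding NEON operations. A minimal sketch of the min(|a|,|b|) step used by the f-function below (the helper name minabs16 and the explicit includes are illustrative, not part of the commit):

#if defined(__arm__) || defined(__aarch64__)
#include <arm_neon.h>        // vabsq_s16, vminq_s16
#else
#include <immintrin.h>       // _mm_abs_epi16, _mm_min_epi16
#endif
// min(|a|,|b|) over eight int16 lanes; on ARM the macros above turn the
// SSE names into vabsq_s16/vminq_s16, so the body is identical on both ISAs.
static inline __m128i minabs16(__m128i a, __m128i b) {
  return _mm_min_epi16(_mm_abs_epi16(a), _mm_abs_epi16(b));
}
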
void applyFtoleft(t_nrPolar_params *pp,decoder_node_t *node) {
int16_t *alpha_v=node->alpha;
int16_t *alpha_l=node->left->alpha;
@@ -270,7 +282,6 @@ void applyFtoleft(t_nrPolar_params *pp,decoder_node_t *node) {
if (node->left->all_frozen == 0) {
#if defined(__AVX2__)
int avx2mod = (node->Nv/2)&15;
if (avx2mod == 0) {
@@ -284,14 +295,7 @@ void applyFtoleft(t_nrPolar_params *pp,decoder_node_t *node) {
absa256 =_mm256_abs_epi16(a256);
absb256 =_mm256_abs_epi16(b256);
minabs256 =_mm256_min_epi16(absa256,absb256);
- ((__m256i*)alpha_l)[i] =_mm256_sign_epi16(minabs256,_mm256_xor_si256(a256,b256));
- /* for (int j=0;j<16;j++) printf("alphal[%d] %d (%d,%d,%d)\n",
-    (16*i) + j,
-    alpha_l[(16*i)+j],
-    ((int16_t*)&minabs256)[j],
-    alpha_v[(16*i)+j],
-    alpha_v[(16*i)+j+(node->Nv/2)]);
- */
+ ((__m256i*)alpha_l)[i] =_mm256_sign_epi16(minabs256,_mm256_sign_epi16(a256,b256));
}
}
else if (avx2mod == 8) {
@@ -301,7 +305,7 @@ void applyFtoleft(t_nrPolar_params *pp,decoder_node_t *node) {
absa128 =_mm_abs_epi16(a128);
absb128 =_mm_abs_epi16(b128);
minabs128 =_mm_min_epi16(absa128,absb128);
- *((__m128i*)alpha_l) =_mm_sign_epi16(minabs128,_mm_xor_si128(a128,b128));
+ *((__m128i*)alpha_l) =_mm_sign_epi16(minabs128,_mm_sign_epi16(a128,b128));
}
else if (avx2mod == 4) {
__m64 a64,b64,absa64,absb64,minabs64;
@@ -310,11 +314,56 @@ void applyFtoleft(t_nrPolar_params *pp,decoder_node_t *node) {
absa64 =_mm_abs_pi16(a64);
absb64 =_mm_abs_pi16(b64);
minabs64 =_mm_min_pi16(absa64,absb64);
- *((__m64*)alpha_l) =_mm_sign_pi16(minabs64,_mm_xor_si64(a64,b64));
+ *((__m64*)alpha_l) =_mm_sign_pi16(minabs64,_mm_sign_pi16(a64,b64));
}
else
#else
int sse4mod = (node->Nv/2)&7;
int sse4len = node->Nv/2/8;
#if defined(__arm__) || defined(__aarch64__)
int16x8_t signatimesb,comp1,comp2,negminabs128;
int16x8_t zero=vdupq_n_s16(0);
#endif
{
if (sse4mod == 0) {
for (int i=0;i<sse4len;i++) {
__m128i a128,b128,absa128,absb128,minabs128;
a128 =((__m128i*)alpha_v)[i];
b128 =((__m128i*)alpha_v)[i+sse4len];
absa128 =_mm_abs_epi16(a128);
absb128 =_mm_abs_epi16(b128);
minabs128 =_mm_min_epi16(absa128,absb128);
#if defined(__arm__) || defined(__aarch64__)
// unfortunately no direct equivalent to _mm_sign_epi16: select minabs128 or
// -minabs128 according to the sign of a*b, taken from the XOR of the sign bits
signatimesb=veorq_s16(a128,b128);
comp1=vreinterpretq_s16_u16(vcltq_s16(signatimesb,zero));
comp2=vreinterpretq_s16_u16(vcgeq_s16(signatimesb,zero));
negminabs128=vnegq_s16(minabs128);
((__m128i*)alpha_l)[i] =vorrq_s16(vandq_s16(minabs128,comp2),vandq_s16(negminabs128,comp1));
#else
((__m128i*)alpha_l)[i] =_mm_sign_epi16(minabs128,_mm_sign_epi16(a128,b128));
#endif
}
}
else if (sse4mod == 4) {
__m64 a64,b64,absa64,absb64,minabs64;
a64 =*((__m64*)alpha_v);
b64 =((__m64*)alpha_v)[1];
absa64 =_mm_abs_pi16(a64);
absb64 =_mm_abs_pi16(b64);
minabs64 =_mm_min_pi16(absa64,absb64);
#if defined(__arm__) || defined(__aarch64__)
AssertFatal(1==0,"Need to do this still for ARM\n");
#else
*((__m64*)alpha_l) =_mm_sign_pi16(minabs64,_mm_sign_pi16(a64,b64));
#endif
}
else
#endif
{ // equivalent scalar code to above, activated only on non x86/ARM architectures
for (int i=0;i<node->Nv/2;i++) {
a=alpha_v[i];
b=alpha_v[i+(node->Nv/2)];
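
Since NEON has no direct counterpart to _mm_sign_epi16, the ARM branch above selects between minabs128 and its negation with an explicit compare/and/or sequence. The same per-lane select can also be written with NEON's bit-select instruction; a sketch under the assumption that the zeroing behaviour of _mm_sign_epi16 for a zero selector is not needed by the min-sum f-function (the helper name neon_sign16 is illustrative, not part of the commit):

#if defined(__arm__) || defined(__aarch64__)
#include <arm_neon.h>
// y = x where s >= 0, y = -x where s < 0 (s == 0 keeps x, unlike _mm_sign_epi16)
static inline int16x8_t neon_sign16(int16x8_t x, int16x8_t s) {
  uint16x8_t neg = vcltq_s16(s, vdupq_n_s16(0));  // 0xFFFF in lanes where s < 0
  return vbslq_s16(neg, vnegq_s16(x), x);         // pick -x there, x elsewhere
}
#endif
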
@@ -367,9 +416,34 @@ void applyGtoright(t_nrPolar_params *pp,decoder_node_t *node) {
else if (avx2mod == 8) {
((__m128i *)alpha_r)[0] = _mm_subs_epi16(((__m128i *)alpha_v)[1],_mm_sign_epi16(((__m128i *)alpha_v)[0],((__m128i *)betal)[0]));
}
else if (avx2mod == 4) {
((__m64 *)alpha_r)[0] = _mm_subs_pi16(((__m64 *)alpha_v)[1],_mm_sign_pi16(((__m64 *)alpha_v)[0],((__m64 *)betal)[0]));
}
else
#else
int sse4mod = (node->Nv/2)&7;
if (sse4mod == 0) {
int sse4len = node->Nv/2/8;
for (int i=0;i<sse4len;i++) {
#if defined(__arm__) || defined(__aarch64__)
((int16x8_t *)alpha_r)[i] = vsubq_s16(((int16x8_t *)alpha_v)[i+sse4len],vmulq_s16(((int16x8_t *)alpha_v)[i],((int16x8_t *)betal)[i]));
#else
((__m128i *)alpha_r)[i] = _mm_subs_epi16(((__m128i *)alpha_v)[i+sse4len],_mm_sign_epi16(((__m128i *)alpha_v)[i],((__m128i *)betal)[i]));
#endif
}
}
else if (sse4mod == 4) {
#if defined(__arm__) || defined(__aarch64__)
((int16x4_t *)alpha_r)[0] = vsub_s16(((int16x4_t *)alpha_v)[1],vmul_s16(((int16x4_t *)alpha_v)[0],((int16x4_t *)betal)[0]));
#else
((__m64 *)alpha_r)[0] = _mm_subs_pi16(((__m64 *)alpha_v)[1],_mm_sign_pi16(((__m64 *)alpha_v)[0],((__m64 *)betal)[0]));
#endif
}
else
#endif
{ // equivalent scalar code to above, activated only on non x86/ARM architectures
for (int i=0;i<node->Nv/2;i++) {
alpha_r[i] = alpha_v[i+(node->Nv/2)] - (betal[i]*alpha_v[i]);
}
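
The g-function update above computes alpha_r[i] = alpha_v[i+Nv/2] - betal[i]*alpha_v[i], with betal holding hard decisions as +1/-1, which is why the NEON branch uses a plain multiply where SSE uses _mm_sign_epi16. On NEON the multiply and subtract can also be fused; a sketch with an illustrative helper name (g_update8, not part of the commit), assuming eight 16-bit LLRs per call:

#if defined(__arm__) || defined(__aarch64__)
#include <arm_neon.h>
// alpha_hi - betal*alpha_lo over eight int16 lanes (betal lanes are +1/-1)
static inline int16x8_t g_update8(int16x8_t alpha_lo, int16x8_t alpha_hi, int16x8_t betal) {
  return vmlsq_s16(alpha_hi, betal, alpha_lo);
}
#endif
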
@@ -385,10 +459,10 @@ void applyGtoright(t_nrPolar_params *pp,decoder_node_t *node) {
}
- int16_t minus1[16] = {-1,-1,-1,-1,
-                       -1,-1,-1,-1,
-                       -1,-1,-1,-1,
-                       -1,-1,-1,-1};
+ int16_t all1[16] = {1,1,1,1,
+                     1,1,1,1,
+                     1,1,1,1,
+                     1,1,1,1};
void computeBeta(t_nrPolar_params *pp,decoder_node_t *node) {
@@ -401,27 +475,37 @@ void computeBeta(t_nrPolar_params *pp,decoder_node_t *node) {
if (node->left->all_frozen==0) { // if left node is not aggregation of frozen bits
#if defined(__AVX2__)
int avx2mod = (node->Nv/2)&15;
register __m256i allones=*((__m256i*)all1);
if (avx2mod == 0) {
int avx2len = node->Nv/2/16;
for (int i=0;i<avx2len;i++) {
- ((__m256i*)betav)[i] = _mm256_sign_epi16(((__m256i*)betar)[i],
-                                          ((__m256i*)betal)[i]);
- ((__m256i*)betav)[i] = _mm256_sign_epi16(((__m256i*)betav)[i],
-                                          ((__m256i*)minus1)[0]);
+ ((__m256i*)betav)[i] = _mm256_or_si256(_mm256_cmpeq_epi16(((__m256i*)betar)[i],
+                                                           ((__m256i*)betal)[i]),allones);
}
}
else if (avx2mod == 8) {
- ((__m128i*)betav)[0] = _mm_sign_epi16(((__m128i*)betar)[0],
-                                       ((__m128i*)betal)[0]);
- ((__m128i*)betav)[0] = _mm_sign_epi16(((__m128i*)betav)[0],
-                                       ((__m128i*)minus1)[0]);
+ ((__m128i*)betav)[0] = _mm_or_si128(_mm_cmpeq_epi16(((__m128i*)betar)[0],
+                                                     ((__m128i*)betal)[0]),*((__m128i*)all1));
}
else if (avx2mod == 4) {
- ((__m64*)betav)[0] = _mm_sign_pi16(((__m64*)betar)[0],
-                                    ((__m64*)betal)[0]);
- ((__m64*)betav)[0] = _mm_sign_pi16(((__m64*)betav)[0],
-                                    ((__m64*)minus1)[0]);
+ ((__m64*)betav)[0] = _mm_or_si64(_mm_cmpeq_pi16(((__m64*)betar)[0],
+                                                 ((__m64*)betal)[0]),*((__m64*)all1));
}
else
#else
int sse4mod = (node->Nv/2)&7;
if (sse4mod == 0) {
int sse4len = node->Nv/2/8;
register __m128i allones=*((__m128i*)all1);
for (int i=0;i<sse4len;i++) {
((__m128i*)betav)[i] = _mm_or_si128(_mm_cmpeq_epi16(((__m128i*)betar)[i],
                                                    ((__m128i*)betal)[i]),allones);
}
}
else if (sse4mod == 4) {
((__m64*)betav)[0] = _mm_or_si64(_mm_cmpeq_pi16(((__m64*)betar)[0],
((__m64*)betal)[0]),*((__m64*)all1));
}
else
#endif
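
The cmpeq/or construction above encodes the beta combination directly in the +1/-1 domain: lanes where betal and betar agree give 0xFFFF|1 = -1, and lanes where they differ give 0|1 = +1. A scalar sketch of the same per-element rule (the helper name beta_combine is illustrative, not part of the commit):

#include <stdint.h>
// beta values stored as +1/-1; agreement maps to -1, disagreement to +1
static inline int16_t beta_combine(int16_t bl, int16_t br) {
  return (bl == br) ? -1 : 1;
}
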
...