Commit ac06f107 authored by Laurent THOMAS's avatar Laurent THOMAS

add an optimized complex vector multiplication function

parent 220d851d
......@@ -1333,50 +1333,6 @@ void nr_a_sum_b(c16_t *input_x, c16_t *input_y, unsigned short nb_rb)
}
}
/* Zero Forcing Rx function: nr_a_mult_b()
 * Element-wise complex multiplication c = a * b on vectors of c16_t.
 *
 * a, b           input vectors, read as simde__m128i (4 complex samples per access)
 * c              output vector, written as simde__m128i
 * nb_rb          the loop runs 3 * nb_rb times, one 128-bit vector per iteration
 * output_shift0  arithmetic right shift applied to the 32-bit products before
 *                they are packed back to 16 bits with signed saturation
 * */
void nr_a_mult_b(c16_t *a, c16_t *b, c16_t *c, unsigned short nb_rb, unsigned char output_shift0)
{
  // Sign pattern that keeps the first int16 of each sample and negates the second one
  short nr_conjugate[8]__attribute__((aligned(16))) = {1,-1,1,-1,1,-1,1,-1};
  simde__m128i *lhs = (simde__m128i *)a;
  simde__m128i *rhs = (simde__m128i *)b;
  simde__m128i *dst = (simde__m128i *)c;
  simde__m128i re32, im32, lo_pairs, hi_pairs;
  unsigned short vec;

  for (vec = 0; vec < 3 * nb_rb; vec++, lhs++, rhs++, dst++) {
    // Real part: a_re*b_re - a_im*b_im (imaginary half of a negated before the multiply-add)
    re32 = simde_mm_sign_epi16(*lhs, *(simde__m128i *)&nr_conjugate[0]);
    re32 = simde_mm_madd_epi16(re32, *rhs);

    // Imaginary part: a_im*b_re + a_re*b_im (the two halves of every sample of a swapped)
    im32 = simde_mm_shufflelo_epi16(*lhs, SIMDE_MM_SHUFFLE(2, 3, 0, 1));
    im32 = simde_mm_shufflehi_epi16(im32, SIMDE_MM_SHUFFLE(2, 3, 0, 1));
    im32 = simde_mm_madd_epi16(im32, *rhs);

    // Scale both 32-bit results
    re32 = simde_mm_srai_epi32(re32, output_shift0);
    im32 = simde_mm_srai_epi32(im32, output_shift0);

    // Re-interleave as (re, im) pairs and saturate down to int16
    lo_pairs = simde_mm_unpacklo_epi32(re32, im32);
    hi_pairs = simde_mm_unpackhi_epi32(re32, im32);
    *dst = simde_mm_packs_epi32(lo_pairs, hi_pairs);
  }
}
/* Zero Forcing Rx function: nr_element_sign()
* Compute b=sign*a
*
......@@ -1448,7 +1404,7 @@ static void nr_determin(int size,
nb_rb,
((rtx & 1) == 1 ? -1 : 1) * ((ctx & 1) == 1 ? -1 : 1) * sign,
shift0);
nr_a_mult_b(a44[ctx][rtx], outtemp, rtx == 0 ? ad_bc : outtemp1, nb_rb, shift0);
mult_complex_vectors(a44[ctx][rtx], outtemp, rtx == 0 ? ad_bc : outtemp1, sizeofArray(outtemp1), shift0);
if (rtx != 0)
nr_a_sum_b(ad_bc, outtemp1, nb_rb);
......@@ -1759,11 +1715,11 @@ static void nr_dlsch_mmse(uint32_t rx_size_symbol,
// printf("Computing r_%d c_%d\n",rtx,ctx);
// print_shorts(" H_h_H=",(int16_t*)&conjH_H_elements[ctx*nl+rtx][0][0]);
// print_shorts(" Inv_H_h_H=",(int16_t*)&inv_H_h_H[ctx*nl+rtx][0]);
nr_a_mult_b(inv_H_h_H[ctx][rtx],
(c16_t *)(rxdataF_comp[ctx][0] + symbol * rx_size_symbol),
outtemp,
nb_rb_0,
shift - (fp_flag == 1 ? 2 : 0));
mult_complex_vectors(inv_H_h_H[ctx][rtx],
(c16_t *)(rxdataF_comp[ctx][0] + symbol * rx_size_symbol),
outtemp,
sizeofArray(outtemp),
shift - (fp_flag == 1 ? 2 : 0));
nr_a_sum_b(rxdataF_zforcing[rtx], outtemp, nb_rb_0); // a = a + b
}
#ifdef DEBUG_DLSCH_DEMOD
......
......@@ -16,8 +16,15 @@ add_dependencies(tests test_log2_approx)
add_test(NAME test_log2_approx
COMMAND ./test_log2_approx --gtest_filter=-log2_approx.complete)
# dft_test: executable built from test_dft.c together with ../dfts_load.c
add_executable(dft_test test_dft.c ../dfts_load.c)
target_link_libraries(dft_test minimal_lib shlib_loader SIMU m)
# building the aggregate "tests" target also builds dft_test
add_dependencies(tests dft_test)
add_dependencies(dft_test dfts) # trigger build of dfts (shared lib for DFT)
add_test(NAME dft_test COMMAND ./dft_test)
# test_vector_op: executable built from test_vector_op.cpp, linked against LOG and minimal_lib
# NOTE(review): presumably this is the test exercising mult_complex_vectors() - confirm
add_executable(test_vector_op test_vector_op.cpp)
target_link_libraries(test_vector_op PRIVATE LOG minimal_lib)
# building the aggregate "tests" target also builds test_vector_op
add_dependencies(tests test_vector_op)
# COMMAND names the executable target rather than a ./ relative path (unlike dft_test above)
add_test(NAME test_vector_op
COMMAND test_vector_op)
......@@ -19,3 +19,52 @@
* contact@openairinterface.org
*/
#include <stdint.h>
#include <vector>
#include <algorithm>
#include <numeric>
extern "C" {
#include "openair1/PHY/TOOLS/tools_defs.h"
struct configmodule_interface_s;
struct configmodule_interface_s *uniqCfg = NULL;
// Local definition of exit_function() for this test executable.
// Only the `assert` flag is examined: a non-zero value aborts the process,
// zero terminates it with EXIT_SUCCESS. file, function, line and s are ignored.
void exit_function(const char *file, const char *function, const int line, const char *s, const int assert)
{
  if (!assert)
    exit(EXIT_SUCCESS);
  abort();
}
}
#include <cstdio>
#include "common/utils/LOG/log.h"
#include "openair1/PHY/TOOLS/phy_test_tools.hpp"
int main()
{
const int shift = 15; // it should always be 15 to keep int16 in same range
for (int vector_size = 1237; vector_size < 1237 + 8; vector_size++) {
auto input1 = generate_random_c16(vector_size);
auto input2 = generate_random_c16(vector_size);
AlignedVector512<c16_t> output;
output.resize(vector_size);
mult_complex_vectors(input1.data(), input2.data(), output.data(), vector_size, shift);
for (int i = 0; i < vector_size; i++) {
c16_t res = c16mulShift(input1[i], input2[i], shift);
if (output[i].r != res.r || output[i].i != res.i) {
printf("Error at %d: (%d,%d) * (%d,%d) = (%d,%d) (should be (%d,%d))\n",
i,
input1[i].r,
input1[i].i,
input2[i].r,
input2[i].i,
output[i].r,
output[i].i,
res.r,
res.i);
return 1;
}
}
}
return 0;
}
......@@ -357,6 +357,74 @@ static __attribute__((always_inline)) inline void multadd_real_four_symbols_vect
simde_mm_storeu_si128((simd_q15_t *)y, y_128);
}
// Multiply two vectors of complex int16 element by element and keep 16 bits of each
// 32-bit product after an arithmetic right shift (shift by 15 in normal case):
//   out[k].r = (int16_t)((in1[k].r * in2[k].r - in1[k].i * in2[k].i) >> shift)
//   out[k].i = (int16_t)((in1[k].r * in2[k].i + in1[k].i * in2[k].r) >> shift)
// The result is truncated to 16 bits, not saturated.
//
// in1, in2  input vectors of `size` samples, read with unaligned loads
// out       output vector of `size` samples, written with unaligned stores
// size      number of complex samples; nothing is done when size <= 0
// shift     right shift applied to the 32-bit products, expected in 0..31
//
// 8 samples are processed per iteration, then one group of 4 if at least 4 remain,
// then the last (up to 3) samples go through c16mulShift(). The SIMD paths are written
// to give the same result as c16mulShift(), including for components equal to -32768.
//
// works only with little endian storage (for big endian, modify the re/im mask and the
// srai/slli at the end)
static __attribute__((always_inline)) inline void mult_complex_vectors(const c16_t *in1,
                                                                       const c16_t *in2,
                                                                       c16_t *out,
                                                                       const int size,
                                                                       const int shift)
{
  // byte shuffle exchanging the two int16 of every sample: (re, im) -> (im, re)
  const simde__m256i complex_shuffle256 = simde_mm256_set_epi8(29, 28, 31, 30, 25, 24, 27, 26,
                                                               21, 20, 23, 22, 17, 16, 19, 18,
                                                               13, 12, 15, 14, 9, 8, 11, 10,
                                                               5, 4, 7, 6, 1, 0, 3, 2);
  // keeps the low int16 (real part) of every 32-bit sample and clears the imaginary part
  const simde__m256i re_mask256 = simde_mm256_set1_epi32(0x0000FFFF);
  const simde__m128i complex_shuffle128 = simde_mm256_castsi256_si128(complex_shuffle256);
  const simde__m128i re_mask128 = simde_mm256_castsi256_si128(re_mask256);
  int i;
  // do 8 multiplications at a time
  for (i = 0; i < size - 7; i += 8) {
    const simde__m256i i1 = simde_mm256_loadu_epi32((const simde__m256i *)(in1 + i));
    const simde__m256i i2 = simde_mm256_loadu_epi32((const simde__m256i *)(in2 + i));
    const simde__m256i i2swap = simde_mm256_shuffle_epi8(i2, complex_shuffle256);
    // The real part is computed as a 32-bit difference of the two partial products.
    // Negating the imaginary part in 16 bits (sign_epi16) is wrong for -32768, which
    // has no int16 opposite.
    const simde__m256i rr = simde_mm256_madd_epi16(simde_mm256_and_si256(i1, re_mask256), i2); // r1*r2
    const simde__m256i ii = simde_mm256_madd_epi16(simde_mm256_andnot_si256(re_mask256, i1), i2); // i1*i2
    const simde__m256i re = simde_mm256_sub_epi32(rr, ii);
    const simde__m256i im = simde_mm256_madd_epi16(i1, i2swap); // r1*i2 + i1*r2
    // low int16 of each sample <- re >> shift, high int16 <- im >> shift.
    // im is shifted right first and then moved to the upper half, so any shift in 0..31
    // is handled (a single left shift by 16 - shift is only valid for shift <= 16).
    simde_mm256_storeu_si256((simde__m256i *)(out + i),
                             simde_mm256_blend_epi16(simde_mm256_srai_epi32(re, shift),
                                                     simde_mm256_slli_epi32(simde_mm256_srai_epi32(im, shift), 16),
                                                     0xAA));
  }
  // one group of 4 samples, same computation on 128 bits
  if (size - i >= 4) {
    const simde__m128i i1 = simde_mm_loadu_epi32((const simde__m128i *)(in1 + i));
    const simde__m128i i2 = simde_mm_loadu_epi32((const simde__m128i *)(in2 + i));
    const simde__m128i i2swap = simde_mm_shuffle_epi8(i2, complex_shuffle128);
    const simde__m128i rr = simde_mm_madd_epi16(simde_mm_and_si128(i1, re_mask128), i2);
    const simde__m128i ii = simde_mm_madd_epi16(simde_mm_andnot_si128(re_mask128, i1), i2);
    const simde__m128i re = simde_mm_sub_epi32(rr, ii);
    const simde__m128i im = simde_mm_madd_epi16(i1, i2swap);
    simde_mm_storeu_si128((simde__m128i *)(out + i),
                          simde_mm_blend_epi16(simde_mm_srai_epi32(re, shift),
                                               simde_mm_slli_epi32(simde_mm_srai_epi32(im, shift), 16),
                                               0xAA));
    i += 4;
  }
  // scalar tail
  for (; i < size; i++)
    out[i] = c16mulShift(in1[i], in2[i], shift);
}
/*!\fn void multadd_complex_vector_real_scalar(int16_t *x,int16_t alpha,int16_t *y,uint8_t zero_flag,uint32_t N)
This function performs componentwise multiplication and accumulation of a real scalar and a complex vector.
@param x Vector input (Q1.15) in the format |Re0 Im0|Re1 Im 1| ...
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment