diff options
-rw-r--r-- | backend/src/backend/gen_context.cpp | 193 | ||||
-rw-r--r-- | kernels/compiler_workgroup_reduce.cl | 2 | ||||
-rw-r--r-- | utests/compiler_workgroup_reduce.cpp | 5 |
3 files changed, 118 insertions, 82 deletions
diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp index 1663b6f3..4e24816c 100644 --- a/backend/src/backend/gen_context.cpp +++ b/backend/src/backend/gen_context.cpp @@ -2980,99 +2980,136 @@ namespace gbe threadDst = GenRegister::retype(threadDst, inputVal.type); threadExchangeData = GenRegister::retype(threadExchangeData, inputVal.type); - if (inputVal.hstride == GEN_HORIZONTAL_STRIDE_0) { - p->MOV(threadExchangeData, inputVal); - p->pop(); - return; - } - - /* init thread data to min/max/null values */ - p->push(); { - p->curr.execWidth = simd; - wgOpInitValue(p, threadExchangeData, wg_op); - p->MOV(resultVal, inputVal); - } p->pop(); - - GenRegister resultValSingle = resultVal; - resultValSingle.hstride = GEN_HORIZONTAL_STRIDE_0; - resultValSingle.vstride = GEN_VERTICAL_STRIDE_0; - resultValSingle.width = GEN_WIDTH_1; - - GenRegister inputValSingle = inputVal; - inputValSingle.hstride = GEN_HORIZONTAL_STRIDE_0; - inputValSingle.vstride = GEN_VERTICAL_STRIDE_0; - inputValSingle.width = GEN_WIDTH_1; - vector<GenRegister> input; vector<GenRegister> result; - /* make an array of registers for easy accesing */ - for(uint32_t i = 0; i < simd; i++){ - /* add all resultVal offset reg positions from list */ - result.push_back(resultValSingle); - input.push_back(inputValSingle); - - /* move to next position */ - resultValSingle.subnr += typeSize(resultValSingle.type); - if (resultValSingle.subnr == 32) { - resultValSingle.subnr = 0; - resultValSingle.nr++; + /* for workgroup all and any we can use simd_all/any for each thread */ + if (wg_op == ir::WORKGROUP_OP_ALL || wg_op == ir::WORKGROUP_OP_ANY) { + GenRegister constZero = GenRegister::immuw(0); + GenRegister flag01 = GenRegister::flag(0, 1); + + p->push(); + { + p->curr.predicate = GEN_PREDICATE_NONE; + p->curr.noMask = 1; + p->curr.execWidth = simd; + p->MOV(resultVal, GenRegister::immud(1)); + p->curr.execWidth = 1; + if (wg_op == ir::WORKGROUP_OP_ALL) + p->MOV(flag01, GenRegister::immw(-1)); + else + p->MOV(flag01, constZero); + + p->curr.execWidth = simd; + p->curr.noMask = 0; + + p->curr.flag = 0; + p->curr.subFlag = 1; + p->CMP(GEN_CONDITIONAL_NEQ, inputVal, constZero); + + if (p->curr.execWidth == 16) + if (wg_op == ir::WORKGROUP_OP_ALL) + p->curr.predicate = GEN_PREDICATE_ALIGN1_ALL16H; + else + p->curr.predicate = GEN_PREDICATE_ALIGN1_ANY16H; + else if (p->curr.execWidth == 8) + if (wg_op == ir::WORKGROUP_OP_ALL) + p->curr.predicate = GEN_PREDICATE_ALIGN1_ALL8H; + else + p->curr.predicate = GEN_PREDICATE_ALIGN1_ANY8H; + else + NOT_IMPLEMENTED; + p->SEL(threadDst, resultVal, constZero); + p->SEL(threadExchangeData, resultVal, constZero); } - /* move to next position */ - inputValSingle.subnr += typeSize(inputValSingle.type); - if (inputValSingle.subnr == 32) { - inputValSingle.subnr = 0; - inputValSingle.nr++; + p->pop(); + } else { + if (inputVal.hstride == GEN_HORIZONTAL_STRIDE_0) { + p->MOV(threadExchangeData, inputVal); + p->pop(); + return; } - } - - uint32_t start_i = 0; - if(wg_op == ir::WORKGROUP_OP_ANY || - wg_op == ir::WORKGROUP_OP_ALL || - wg_op == ir::WORKGROUP_OP_REDUCE_ADD || - wg_op == ir::WORKGROUP_OP_REDUCE_MIN || - wg_op == ir::WORKGROUP_OP_REDUCE_MAX || - wg_op == ir::WORKGROUP_OP_INCLUSIVE_ADD || - wg_op == ir::WORKGROUP_OP_INCLUSIVE_MIN || - wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX) { - p->MOV(result[0], input[0]); - start_i = 1; - } - else if(wg_op == ir::WORKGROUP_OP_EXCLUSIVE_ADD || - wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MIN || - wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX) { - p->MOV(result[1], input[0]); - start_i = 2; - } + /* init thread data to min/max/null values */ + p->push(); { + p->curr.execWidth = simd; + wgOpInitValue(p, threadExchangeData, wg_op); + p->MOV(resultVal, inputVal); + } p->pop(); + + GenRegister resultValSingle = resultVal; + resultValSingle.hstride = GEN_HORIZONTAL_STRIDE_0; + resultValSingle.vstride = GEN_VERTICAL_STRIDE_0; + resultValSingle.width = GEN_WIDTH_1; + + GenRegister inputValSingle = inputVal; + inputValSingle.hstride = GEN_HORIZONTAL_STRIDE_0; + inputValSingle.vstride = GEN_VERTICAL_STRIDE_0; + inputValSingle.width = GEN_WIDTH_1; + + + /* make an array of registers for easy accesing */ + for(uint32_t i = 0; i < simd; i++){ + /* add all resultVal offset reg positions from list */ + result.push_back(resultValSingle); + input.push_back(inputValSingle); + + /* move to next position */ + resultValSingle.subnr += typeSize(resultValSingle.type); + if (resultValSingle.subnr == 32) { + resultValSingle.subnr = 0; + resultValSingle.nr++; + } + /* move to next position */ + inputValSingle.subnr += typeSize(inputValSingle.type); + if (inputValSingle.subnr == 32) { + inputValSingle.subnr = 0; + inputValSingle.nr++; + } + } - /* algorithm workgroup */ - for (uint32_t i = start_i; i < simd; i++) - { - if(wg_op == ir::WORKGROUP_OP_ANY || - wg_op == ir::WORKGROUP_OP_ALL || - wg_op == ir::WORKGROUP_OP_REDUCE_ADD || + uint32_t start_i = 0; + if( wg_op == ir::WORKGROUP_OP_REDUCE_ADD || wg_op == ir::WORKGROUP_OP_REDUCE_MIN || - wg_op == ir::WORKGROUP_OP_REDUCE_MAX) - wgOpPerform(result[0], result[0], input[i], wg_op, p); - - else if(wg_op == ir::WORKGROUP_OP_INCLUSIVE_ADD || + wg_op == ir::WORKGROUP_OP_REDUCE_MAX || + wg_op == ir::WORKGROUP_OP_INCLUSIVE_ADD || wg_op == ir::WORKGROUP_OP_INCLUSIVE_MIN || - wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX) - wgOpPerform(result[i], result[i - 1], input[i], wg_op, p); + wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX) { + p->MOV(result[0], input[0]); + start_i = 1; + } else if(wg_op == ir::WORKGROUP_OP_EXCLUSIVE_ADD || wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MIN || - wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX) - wgOpPerform(result[i], result[i - 1], input[i - 1], wg_op, p); + wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX) { + p->MOV(result[1], input[0]); + start_i = 2; + } - else - GBE_ASSERT(0); + /* algorithm workgroup */ + for (uint32_t i = start_i; i < simd; i++) + { + if( wg_op == ir::WORKGROUP_OP_REDUCE_ADD || + wg_op == ir::WORKGROUP_OP_REDUCE_MIN || + wg_op == ir::WORKGROUP_OP_REDUCE_MAX) + wgOpPerform(result[0], result[0], input[i], wg_op, p); + + else if(wg_op == ir::WORKGROUP_OP_INCLUSIVE_ADD || + wg_op == ir::WORKGROUP_OP_INCLUSIVE_MIN || + wg_op == ir::WORKGROUP_OP_INCLUSIVE_MAX) + wgOpPerform(result[i], result[i - 1], input[i], wg_op, p); + + else if(wg_op == ir::WORKGROUP_OP_EXCLUSIVE_ADD || + wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MIN || + wg_op == ir::WORKGROUP_OP_EXCLUSIVE_MAX) + wgOpPerform(result[i], result[i - 1], input[i - 1], wg_op, p); + + else + GBE_ASSERT(0); + } } - if(wg_op == ir::WORKGROUP_OP_ANY || - wg_op == ir::WORKGROUP_OP_ALL || - wg_op == ir::WORKGROUP_OP_REDUCE_ADD || + if( wg_op == ir::WORKGROUP_OP_REDUCE_ADD || wg_op == ir::WORKGROUP_OP_REDUCE_MIN || wg_op == ir::WORKGROUP_OP_REDUCE_MAX) { diff --git a/kernels/compiler_workgroup_reduce.cl b/kernels/compiler_workgroup_reduce.cl index d2f37ca9..69dcea8e 100644 --- a/kernels/compiler_workgroup_reduce.cl +++ b/kernels/compiler_workgroup_reduce.cl @@ -7,7 +7,7 @@ kernel void compiler_workgroup_any(global int *src, global int *dst) { dst[get_global_id(0)] = predicate; } kernel void compiler_workgroup_all(global int *src, global int *dst) { - char val = src[get_global_id(0)]; + int val = src[get_global_id(0)]; int predicate = work_group_all(val); dst[get_global_id(0)] = predicate; } diff --git a/utests/compiler_workgroup_reduce.cpp b/utests/compiler_workgroup_reduce.cpp index f185666f..4003bf88 100644 --- a/utests/compiler_workgroup_reduce.cpp +++ b/utests/compiler_workgroup_reduce.cpp @@ -38,7 +38,7 @@ static void compute_expected(WG_FUNCTION wg_func, { T wg_predicate = input[0]; for(uint32_t i = 1; i < WG_LOCAL_SIZE; i++) - wg_predicate = (int)wg_predicate | (int)input[i]; + wg_predicate = (int)wg_predicate || (int)input[i]; for(uint32_t i = 0; i < WG_LOCAL_SIZE; i++) expected[i] = wg_predicate; } @@ -46,7 +46,7 @@ static void compute_expected(WG_FUNCTION wg_func, { T wg_predicate = input[0]; for(uint32_t i = 1; i < WG_LOCAL_SIZE; i++) - wg_predicate = (int)wg_predicate & (int)input[i]; + wg_predicate = (int)wg_predicate && (int)input[i]; for(uint32_t i = 0; i < WG_LOCAL_SIZE; i++) expected[i] = wg_predicate; } @@ -112,7 +112,6 @@ static void generate_data(WG_FUNCTION wg_func, /* add trailing random bits, tests GENERAL cases */ input[gid + lid] += (rand() % 112); /* always last bit is 1, ideal test ALL/ANY */ - input[gid + lid] = (T)((long)input[gid + lid] | (long)1); } else { input[gid + lid] += rand(); input[gid + lid] += rand() / ((float)RAND_MAX + 1); |